-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclassify_papers.py
More file actions
252 lines (226 loc) · 8.99 KB
/
classify_papers.py
File metadata and controls
252 lines (226 loc) · 8.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
#!/usr/bin/env python3
"""
对ICLR2026论文进行研究领域分类
分类体系(15类):
1. LLM训练与优化
2. LLM推理与效率
3. LLM对齐与安全
4. 智能体与工具使用
5. 代码生成与程序合成
6. 多模态LLM
7. 推理与数学
8. 检索增强与知识
9. 强化学习
10. 具身智能与机器人
11. 评测与基准
12. 可解释性与分析
13. NLP应用
14. 多智能体系统
15. 其他
"""
import json
import re
from collections import defaultdict
# 分类规则:关键词映射到类别
CATEGORY_RULES = {
"LLM训练与优化": [
"预训练", "pre-training", "微调", "fine-tun", "指令微调", "instruction tun",
"参数高效", "peft", "lora", "知识蒸馏", "蒸馏", "distill",
"持续学习", "continual learn", "模型训练", "训练效率", "数据选择",
"合成数据", "synthetic data", "课程学习", "curriculum",
"大语言模型训练", "大语言模型预训练", "大语言模型微调",
],
"LLM推理与效率": [
"推理加速", "inference", "量化", "quantiz", "剪枝", "prun",
"推测解码", "speculative decod", "高效推理", "efficient inference",
"模型压缩", "compression", "缓存", "kv cache", "长上下文",
"大语言模型推理优化", "大语言模型推理加速", "推测解码",
],
"LLM对齐与安全": [
"对齐", "alignment", "rlhf", "偏好学习", "preference learn",
"安全", "safety", "越狱", "jailbreak", "有害", "harmful",
"对抗攻击", "adversarial attack", "鲁棒性", "robust",
"ai安全", "ai对齐", "模型安全", "大语言模型安全", "大语言模型对齐",
"模型对齐", "奖励模型", "reward model",
],
"智能体与工具使用": [
"智能体", "agent", "工具使用", "tool use", "工具调用",
"规划", "planning", "任务规划", "web agent", "gui agent",
"移动智能体", "mobile agent", "llm智能体", "大语言模型智能体",
"智能体系统", "智能体评估", "自主智能体", "autonomous agent",
],
"代码生成与程序合成": [
"代码生成", "code gen", "程序合成", "program synthesis",
"代码补全", "code complet", "软件工程", "software engineer",
"程序分析", "program analysis", "代码理解", "code understand",
"自动化编程", "automated program", "代码修复", "bug fix",
"代码搜索", "code search",
],
"多模态LLM": [
"多模态", "multimodal", "视觉语言", "vision language",
"图文", "image text", "视频理解", "video understand",
"视觉问答", "visual question", "图像生成", "image gen",
"多模态大语言模型", "视觉-语言", "视觉推理", "visual reason",
"多模态推理", "多模态学习", "多模态人工智能",
],
"推理与数学": [
"数学推理", "math reason", "逻辑推理", "logical reason",
"链式思考", "chain of thought", "cot", "推理能力",
"因果推理", "causal reason", "空间推理", "spatial reason",
"符号推理", "symbolic reason", "大语言模型推理",
"推理增强", "推理模型", "数学问题",
],
"检索增强与知识": [
"检索增强", "retrieval augment", "rag", "知识图谱", "knowledge graph",
"信息检索", "information retrieval", "知识库", "knowledge base",
"文档理解", "document understand", "问答", "question answer",
"知识编辑", "knowledge edit", "外部知识",
],
"强化学习": [
"强化学习", "reinforcement learn", "策略优化", "policy optim",
"奖励", "reward", "马尔可夫", "markov", "q-learning",
"actor-critic", "ppo", "dpo", "离线强化学习", "offline rl",
"在线强化学习", "online rl", "多智能体强化学习",
],
"具身智能与机器人": [
"具身智能", "embodied", "机器人", "robot", "机器人学",
"导航", "navigation", "操控", "manipulat", "行为基础模型",
"运动控制", "motion control", "仿真", "simulat",
],
"评测与基准": [
"基准测试", "benchmark", "评估", "evaluat", "数据集",
"dataset", "数据集构建", "数据集与基准", "评测",
"leaderboard", "排行榜", "测试集",
],
"可解释性与分析": [
"可解释", "interpret", "可解释人工智能", "xai",
"机制分析", "mechanistic", "神经元", "neuron",
"注意力分析", "attention analysis", "特征归因", "attribution",
"模型分析", "model analysis", "内部表示", "representation",
"模型可解释性",
],
"NLP应用": [
"自然语言处理", "nlp", "对话系统", "dialogue", "对话",
"机器翻译", "translation", "文本摘要", "summariz",
"情感分析", "sentiment", "文本分类", "text classif",
"命名实体", "ner", "信息抽取", "information extract",
"文本生成", "text gen", "语言生成",
],
"多智能体系统": [
"多智能体", "multi-agent", "多agent", "协作", "collaborat",
"通信", "communicat", "涌现", "emergent", "群体智能",
"swarm", "多智能体系统",
],
}
# 优先级顺序(越靠前优先级越高,避免重叠)
CATEGORY_PRIORITY = [
"代码生成与程序合成",
"具身智能与机器人",
"多智能体系统",
"智能体与工具使用",
"多模态LLM",
"推理与数学",
"检索增强与知识",
"LLM对齐与安全",
"LLM推理与效率",
"LLM训练与优化",
"强化学习",
"可解释性与分析",
"评测与基准",
"NLP应用",
"其他",
]
def classify_paper(paper):
"""根据论文的 research_fields、keywords、title 进行分类"""
# 提取文本特征
texts = []
texts.append(paper.get("title", "").lower())
texts.append(paper.get("keywords", "").lower())
conclusion_raw = paper.get("conclusion", "")
if isinstance(conclusion_raw, str):
# 尝试解析 JSON
try:
# 清理可能的 markdown 代码块
clean = re.sub(r"```json\s*|```\s*", "", conclusion_raw).strip()
conclusion = json.loads(clean)
except Exception:
conclusion = {}
elif isinstance(conclusion_raw, dict):
conclusion = conclusion_raw
else:
conclusion = {}
# 从 conclusion 中提取 research_fields 和 task_type
fields = conclusion.get("research_fields", [])
if isinstance(fields, list):
texts.extend([f.lower() for f in fields])
elif isinstance(fields, str):
texts.append(fields.lower())
task_type = conclusion.get("task_type", "")
if isinstance(task_type, list):
texts.extend([t.lower() for t in task_type])
elif isinstance(task_type, str):
texts.append(task_type.lower())
combined = " ".join(texts)
# 按优先级匹配
scores = defaultdict(int)
for category, keywords in CATEGORY_RULES.items():
for kw in keywords:
if kw.lower() in combined:
scores[category] += 1
if not scores:
return "其他"
# 按优先级选择得分最高的类别
best_category = "其他"
best_score = 0
for cat in CATEGORY_PRIORITY:
if cat == "其他":
continue
s = scores.get(cat, 0)
if s > best_score:
best_score = s
best_category = cat
return best_category
def main():
with open("extracted_data.json", "r", encoding="utf-8") as f:
papers = json.load(f)
print(f"总论文数: {len(papers)}")
results = []
category_count = defaultdict(int)
for paper in papers:
category = classify_paper(paper)
category_count[category] += 1
results.append({
"id": paper.get("id", ""),
"number": paper.get("number", ""),
"title": paper.get("title", ""),
"category": category,
"keywords": paper.get("keywords", ""),
})
# 输出统计
total = len(papers)
print("\n=== 研究领域分类统计 ===")
for cat in CATEGORY_PRIORITY:
count = category_count.get(cat, 0)
pct = count / total * 100
print(f" {cat:<20} {count:>5} 篇 ({pct:.1f}%)")
# 保存分类结果 JSON
with open("paper_categories2.json", "w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=2)
print("\n已保存: paper_categories.json")
# 保存统计数据
stats = {
"total": total,
"categories": {
cat: {
"count": category_count.get(cat, 0),
"percentage": round(category_count.get(cat, 0) / total * 100, 2)
}
for cat in CATEGORY_PRIORITY
}
}
with open("category_stats.json", "w", encoding="utf-8") as f:
json.dump(stats, f, ensure_ascii=False, indent=2)
print("已保存: category_stats.json")
return results, category_count
if __name__ == "__main__":
main()