-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_category_md_from_split.py
More file actions
140 lines (112 loc) · 4.11 KB
/
generate_category_md_from_split.py
File metadata and controls
140 lines (112 loc) · 4.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python3
"""按 iclr2026_paper_classification3.md 的分类,为每个类别生成单独的 Markdown 文件。"""
import glob
import json
import os
import re
from collections import defaultdict
# Root of the analysis workspace. Defaults to the original author's path; set
# the PAPER_ANALYZE_ROOT environment variable (non-empty) to run elsewhere.
ROOT = os.environ.get("PAPER_ANALYZE_ROOT") or "/home/weijia/paper_analyze"
# Input directory holding the per-batch analysis files (analysis_papers_*.md).
SPLIT_DIR = os.path.join(ROOT, "split_results_md")
# Output directory: one Markdown file per category is written here.
OUT_DIR = os.path.join(ROOT, "category_split_results_md")
# Output order of categories. The 1-based position of each entry becomes the
# two-digit prefix of its output file name (e.g. "01_..."). Papers whose
# normalized category is not listed here are not written to any file.
# NOTE(review): "其他" (Other) sits mid-list rather than last — confirm intended.
CATEGORY_ORDER = [
    "智能体与多智能体系统",
    "多模态LLM",
    "LLM对齐与安全",
    "强化学习",
    "评测与基准",
    "LLM训练与优化",
    "LLM推理与效率",
    "可解释性与分析",
    "推理与数学",
    "其他",
    "检索增强与知识",
    "具身智能与机器人",
    "NLP应用",
]
def normalize_category(cat: str):
    """Map a raw category label onto the label used for the output files.

    Returns ``None`` for "代码生成与程序合成" (that category is excluded from
    output), folds "智能体与工具使用" and "多智能体系统" into the combined
    label "智能体与多智能体系统", and returns any other label unchanged.
    """
    excluded = "代码生成与程序合成"
    agent_labels = {"智能体与工具使用", "多智能体系统"}
    if cat == excluded:
        return None
    return "智能体与多智能体系统" if cat in agent_labels else cat
def parse_split_file(path: str):
    """Parse one split-analysis Markdown file into per-paper entries.

    The text is cut at every line starting with a numbered level-2 heading
    (``## <n>. <title>``); each slice runs up to the next such heading or the
    end of the file. A slice is kept only if it has that heading and a
    non-empty ``- paper_id: <id>`` line.

    Args:
        path: Path of the Markdown file to read (decoded as UTF-8).

    Returns:
        A list of dicts with keys ``"title"``, ``"paper_id"`` and ``"block"``
        (the stripped Markdown of the slice), in file order.
    """
    # Close the handle deterministically instead of relying on GC.
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()

    heading_re = re.compile(r"(?m)^##\s+\d+\.\s+")
    title_re = re.compile(r"^##\s+\d+\.\s+(.*)$", re.M)
    # Horizontal whitespace only: with `\s`, an empty "- paper_id:" line would
    # match across the newline and take the next line as the id.
    id_re = re.compile(r"^-[ \t]+paper_id:[ \t]*(.*)$", re.M)

    matches = list(heading_re.finditer(text))
    items = []
    for i, m in enumerate(matches):
        end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
        block = text[m.start():end].strip()
        title_match = title_re.search(block)
        id_match = id_re.search(block)
        paper_id = id_match.group(1).strip() if id_match else ""
        if title_match and paper_id:
            items.append({
                "title": title_match.group(1).strip(),
                "paper_id": paper_id,
                "block": block,
            })
    return items
def renumber_block(block: str, idx: int):
    """Rewrite the first ``## <n>. `` heading in *block* to use number *idx*.

    Only the first matching heading is renumbered. The result is stripped of
    leading/trailing whitespace; a block without such a heading is only stripped.
    """
    heading = re.compile(r"^##\s+\d+\.\s+", re.MULTILINE)
    renumbered = heading.sub(f"## {idx}. ", block, count=1)
    return renumbered.strip()
def safe_filename(name: str):
    """Return *name* with each character that is invalid in file names replaced by ``_``.

    The replaced characters are the backslash and ``/ : * ? " < > |``.
    """
    forbidden = '\\/:*?"<>|'
    table = str.maketrans(dict.fromkeys(forbidden, "_"))
    return name.translate(table)
def add_pdf_url(block: str, pdf_url: str):
    """Insert a ``- pdf_url: <url>`` metadata line into a paper block.

    Placement:
      * directly before the first ``- openreview_url:`` occurrence, else
      * directly after the first line starting with ``- keywords:``, else
      * the block is returned unchanged.

    Args:
        block: Markdown text of a single paper entry.
        pdf_url: URL to insert; falsy values mean "nothing to add".

    Returns:
        The (possibly) updated block. It is returned unchanged if *pdf_url* is
        falsy, if the block already contains ``- pdf_url:``, or if no anchor
        line is found.
    """
    if not pdf_url or "- pdf_url:" in block:
        return block
    pdf_line = f"- pdf_url: {pdf_url}"
    if "- openreview_url:" in block:
        return block.replace("- openreview_url:", f"{pdf_line}\n- openreview_url:", 1)
    if "- keywords:" in block:
        # A function replacement inserts pdf_url literally; a template string
        # would interpret backslashes/group refs inside the URL. `[ \t]` and
        # `.*` keep the match on one line, so an empty "- keywords:" line
        # cannot swallow the line after it.
        return re.sub(
            r"(?m)^-[ \t]+keywords:.*$",
            lambda m: f"{m.group(0)}\n{pdf_line}",
            block,
            count=1,
        )
    return block
def build_header(category: str, count: int):
    """Return the Markdown header placed at the top of each category file.

    ``category`` and ``count`` are accepted but currently unused: every
    category file gets the same title line followed by a blank line.
    """
    title = "ICLR Oral Papers Analysis"
    return f"# {title}\n\n"
# Script body: load metadata, group paper ids by category, and write one
# Markdown file per category in CATEGORY_ORDER.
with open(os.path.join(ROOT, "paper_categories.json"), "r", encoding="utf-8") as f:
    # Presumably a list of {"id": ..., "category": ...} records — TODO confirm schema.
    paper_categories = json.load(f)
with open(os.path.join(ROOT, "iclr26_all_papers.json"), "r", encoding="utf-8") as f:
    all_papers = json.load(f)
# Paper id -> PDF URL ("" when the record has no "pdf_url" key).
id_to_pdf = {item["id"]: item.get("pdf_url", "") for item in all_papers}

# Normalized category -> paper ids, in paper_categories.json order.
cat_to_ids = defaultdict(list)
for item in paper_categories:
    cat = normalize_category(item["category"])
    if cat is not None:  # None marks a deliberately excluded category.
        cat_to_ids[cat].append(item["id"])

# Categories absent from CATEGORY_ORDER are never written; report them so the
# omission is visible instead of silent.
listed_categories = set(CATEGORY_ORDER)
for cat in [c for c in cat_to_ids if c not in listed_categories]:
    print(f"未列入 CATEGORY_ORDER 的类别(已跳过): {cat}: {len(cat_to_ids[cat])} 篇")

# paper_id -> parsed entry; a later file wins if an id appears more than once.
id_to_block = {}
for path in sorted(glob.glob(os.path.join(SPLIT_DIR, "analysis_papers_*.md"))):
    for item in parse_split_file(path):
        id_to_block[item["paper_id"]] = item

os.makedirs(OUT_DIR, exist_ok=True)
summary = []
for idx, category in enumerate(CATEGORY_ORDER, 1):
    paper_ids = cat_to_ids.get(category, [])
    blocks = []
    missing = []
    for pid in paper_ids:
        item = id_to_block.get(pid)
        if item is None:
            missing.append(pid)
            continue
        block = add_pdf_url(item["block"], id_to_pdf.get(pid, ""))
        # Number by position among papers actually written, so a missing
        # paper does not leave a gap in the 1..N numbering.
        blocks.append(renumber_block(block, len(blocks) + 1))
    out_path = os.path.join(OUT_DIR, f"{idx:02d}_{safe_filename(category)}.md")
    header = build_header(category, len(blocks))
    body = "\n\n".join(blocks)
    if body:
        body += "\n"
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(header + body)
    summary.append((category, len(blocks), len(missing), out_path))

print(f"输出目录: {OUT_DIR}")
for category, found, missing, out_path in summary:
    print(f"{category}: {found} 篇, missing={missing}, file={os.path.basename(out_path)}")
missing_total = sum(m for _, _, m, _ in summary)
print(f"总缺失: {missing_total}")