forked from WonbinKweon/TopicK_EMNLP2025
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpreprocess.py
More file actions
254 lines (199 loc) · 6.38 KB
/
preprocess.py
File metadata and controls
254 lines (199 loc) · 6.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
"""
preprocess.py
이 파일의 역할:
1. 데이터 로드
2. 문장 / 토픽 임베딩 생성
3. 토픽 분류용 MLP 로드
4. query → topic score 계산
5. 결과를 pickle로 저장
⚠️ GPU 전용 코드였던 부분을 CPU/GPU 자동 분기되도록 수정함
"""
# =========================================================
# 기본 라이브러리 / 유틸
# =========================================================
import os
import json
import pickle
import math
import re
from pprint import pprint
from collections import Counter
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
from tqdm import tqdm
# HuggingFace / NLP
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import MultiLabelBinarizer
# openicl 구성요소 (retriever / evaluator 등)
from openicl import (
PromptTemplate,
DatasetReader,
RandomRetriever,
BM25Retriever,
ConERetriever,
TopkRetriever,
PPLInferencer,
AccEvaluator,
DPPRetriever,
MDLRetriever,
)
from utils import (
templates,
input_columns,
output_columns,
test_split,
score_mat_2_rank_mat,
omit_substrings,
)
# =========================================================
# Device selection (automatic CPU / GPU fallback)
# =========================================================
# Preference order: CUDA GPU, then Apple MPS, then plain CPU.
_has_cuda = torch.cuda.is_available()
_mps_backend = getattr(torch.backends, "mps", None)

if _has_cuda:
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device:", device)

# CUDA-only switches: turn off the memory-efficient and flash
# scaled-dot-product-attention backends, but only when a CUDA GPU exists.
if _has_cuda:
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_flash_sdp(False)
# =========================================================
# Data loading
# =========================================================
print("loading dataset")
task = "cms"
task_name = task
data_dir = "data/"
# Build paths with os.path.join instead of string "+" concatenation so the
# separator is correct on every platform.
task_dir = os.path.join(data_dir, task_name)
train_path = os.path.join(task_dir, "train.jsonl")
# Name of the evaluation split file for this task (looked up in utils.test_split).
test_name = test_split[task_name]
test_path = os.path.join(task_dir, test_name + ".jsonl")

# Load both JSONL files through HuggingFace `datasets`, keyed "train" / "test".
combined_dataset = load_dataset(
    "json",
    data_files={
        "train": train_path,
        "test": test_path,
    },
)
train_dataset = combined_dataset["train"]
test_dataset = combined_dataset["test"]

# NOTE: pickle.load can execute arbitrary code embedded in the file; only
# point data_dir at files produced by this project / trusted sources.
# Query -> topic mapping.
# NOTE(review): qid2tid_dic is not referenced again in this script — confirm it
# is still needed here.
with open(os.path.join(task_dir, "qid2tid_dic"), "rb") as f:
    qid2tid_dic = pickle.load(f)

# Full list of topics; each entry is encoded as text further below.
with open(os.path.join(task_dir, "topic_list"), "rb") as f:
    topic_list = pickle.load(f)
# =========================================================
# Sentence embedding generation
# =========================================================
print("computing query/topic embeddings")
# Sentence-embedding model, moved to the selected device and put in eval mode.
model_id = "sentence-transformers/all-mpnet-base-v2"
model = SentenceTransformer(model_id)
model = model.to(device)
model.eval()
from torch.utils.data import DataLoader


def _encode_in_batches(encoder, texts, batch_size=1024):
    """Encode ``texts`` with ``encoder`` in chunks of ``batch_size``.

    Shared by the query and topic passes, which were previously two
    copy-pasted loops.

    Args:
        encoder: the SentenceTransformer used for ``encode``.
        texts: an indexable collection of strings.
        batch_size: number of texts handed to ``encoder.encode`` per call.

    Returns:
        A 2-D NumPy array with one row per input text, in input order.
        Empty input yields a ``(0, dim)`` array rather than a shapeless one.
    """
    # DataLoader's default collate turns each batch of strings into a list.
    loader = DataLoader(texts, batch_size=batch_size)
    chunks = []
    for batch in tqdm(loader):
        with torch.no_grad():
            # Keep each batch as one (B, dim) array; concatenate once at the end
            # instead of splitting into per-row arrays and rebuilding.
            chunks.append(encoder.encode(batch))
    if not chunks:
        dim = encoder.get_sentence_embedding_dimension()
        return np.empty((0, dim), dtype=np.float32)
    return np.concatenate(chunks, axis=0)


# Query embeddings: one row per training example's "text" field.
query_emb = _encode_in_batches(model, train_dataset["text"])
# Topic embeddings: one row per entry of topic_list.
topic_emb = _encode_in_batches(model, topic_list)

# -------------------------
# Save
# -------------------------
with open(os.path.join(data_dir, task_name, "query_emb"), "wb") as fw:
    pickle.dump(query_emb, fw, protocol=pickle.HIGHEST_PROTOCOL)
with open(os.path.join(data_dir, task_name, "topic_emb"), "wb") as fw:
    pickle.dump(topic_emb, fw, protocol=pickle.HIGHEST_PROTOCOL)
# =========================================================
# Topic Predictor (MLP)
# =========================================================
class Topic_predictor(nn.Module):
    """Score query embeddings against a fixed set of topic embeddings.

    Architecture:
      - input: a query embedding of width ``dim`` (768 in this pipeline)
      - a 3-layer MLP projection (Linear-ReLU-Linear-ReLU-Linear), dim -> dim
      - dot product with every topic embedding -> one score per topic

    ``dim`` is taken from ``topic_emb.shape[1]`` instead of being hard-coded,
    so query and topic embeddings only need to share the same width. For
    768-d embeddings the parameter names and shapes match the previous
    hard-coded version, so existing checkpoints still load.

    Args:
        topic_emb: ``(num_topics, dim)`` float tensor. Stored as a frozen
            parameter: it is saved in / restored from ``state_dict`` but
            never updated by training.

    Raises:
        ValueError: if ``topic_emb`` is not 2-D.
    """

    def __init__(self, topic_emb):
        super().__init__()
        if topic_emb.dim() != 2:
            raise ValueError(
                "topic_emb must be 2-D (num_topics, dim), "
                f"got shape {tuple(topic_emb.shape)}"
            )
        emb_dim = topic_emb.shape[1]
        # Fixed topic embeddings (requires_grad=False -> excluded from training).
        self.topic_emb = nn.Parameter(topic_emb, requires_grad=False)
        # Simple MLP projection of the query into the topic-embedding space.
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, batch_X):
        """Map ``(B, dim)`` query embeddings to ``(B, num_topics)`` scores."""
        return self.mlp(batch_X) @ self.topic_emb.T
# Build the predictor around the topic embeddings and move it to the target device.
CLF = Topic_predictor(torch.FloatTensor(topic_emb)).to(device)

# Load pretrained weights. map_location lets a checkpoint saved on a GPU load on
# CPU/MPS; weights_only=True restricts unpickling to tensors and plain containers.
# Because topic_emb is a registered (frozen) parameter, the checkpoint's stored
# topic_emb replaces the freshly computed embeddings passed in above.
clf_path = os.path.join(data_dir, task_name, "topic_predictor")
CLF.load_state_dict(
    torch.load(clf_path, map_location=device, weights_only=True)
)
# =========================================================
# Dataset wrapper (for inference)
# =========================================================
class CLF_dataset(data.Dataset):
    """Map-style dataset yielding ``(index, features)`` pairs.

    Returning the index with each row lets a caller fetch the matching labels
    for a batch via :meth:`get_labels`, or simply ignore it during label-free
    inference.

    Args:
        X: indexable collection of feature rows (e.g. a 2-D tensor or a list).
        Y: labels indexable by a batch of indices, or ``None`` if unlabeled.
    """

    def __init__(self, X, Y):
        self.X = X
        self.Y = Y

    def __len__(self):
        # len() works for tensors, NumPy arrays and plain lists alike, whereas
        # .shape[0] only works for array-like inputs.
        return len(self.X)

    def __getitem__(self, idx):
        return idx, self.X[idx]

    def get_labels(self, batch_indices):
        """Return the labels at ``batch_indices``.

        Raises:
            ValueError: if the dataset was built without labels (``Y is None``).
        """
        if self.Y is None:
            raise ValueError("CLF_dataset was created without labels (Y is None)")
        return self.Y[batch_indices]
# Label-free inference set: the query embeddings are the only features.
train_X, train_Y = torch.FloatTensor(query_emb), None
clf_dataset = CLF_dataset(X=train_X, Y=train_Y)

# No shuffling, so row i of the collected logits corresponds to query i.
clf_loader = data.DataLoader(dataset=clf_dataset, batch_size=1024, shuffle=False)
# =========================================================
# Topic score inference
# =========================================================
CLF.eval()
all_logits = []
with torch.no_grad():
    for _, batch_X in clf_loader:
        # Keep each batch as one (B, num_topics) CPU tensor and concatenate once
        # at the end, instead of splitting into per-row tensors and re-stacking.
        all_logits.append(CLF(batch_X.to(device)).cpu())

if all_logits:
    c_clf_logit = torch.cat(all_logits, dim=0)
else:
    # No queries: keep a 2-D result with one column per topic instead of
    # crashing on an empty stack.
    c_clf_logit = torch.empty(0, CLF.topic_emb.shape[0])

# Save the result
with open(os.path.join(data_dir, task_name, "query_clf_logit"), "wb") as fw:
    pickle.dump(c_clf_logit, fw, protocol=pickle.HIGHEST_PROTOCOL)

print("✅ preprocess finished successfully.")