-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
306 lines (264 loc) · 12.9 KB
/
train.py
File metadata and controls
306 lines (264 loc) · 12.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import os
# XLA/JAX read these environment variables when the backend initializes, so they
# must be set here, before `import jax` further down.
# Triton GEMMs for all matmuls + the latency-hiding scheduler (GPU-only flags).
os.environ['XLA_FLAGS'] = (
'--xla_gpu_triton_gemm_any=True '
'--xla_gpu_enable_latency_hiding_scheduler=true '
)
# Fraction of GPU memory JAX preallocates for its client.
# NOTE(review): .97 leaves little room for torch allocations in the same process
# (the torch model is loaded before conversion) — confirm this is intended.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.97'
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer, AutoTokenizer
import torch
import datasets
import numpy as np
import jax
import jax.numpy as jnp
import optax
from torch2jax import t2j
import functools
from tqdm import tqdm
import operator
from dataclasses import dataclass, field
import draccus
import yaml
import peft
@dataclass
class PeftTrainConfig:
    """Settings for one PEFT fine-tuning run on GSM8K; parsed from the CLI by draccus."""

    # Strategy whose `.wrap(model)` produces the PEFT model; LoRA by default.
    peft_config: peft.PeftStrategyConfig = field(default_factory=peft.LoraConfig)
    # Base causal-LM checkpoint for both model and tokenizer. Other checkpoints that
    # have been tried: 'meta-llama/Llama-3.1-8B', 'google/gemma-7b',
    # 'NousResearch/Llama-2-7b-hf', 'NousResearch/Llama-3.2-1B'.
    model_name: str = 'google/gemma-2b'
    # Pool GSM8K train+test and re-split 85/15 instead of using the official split.
    dataset_randomize: bool = False
    # Tokenizer truncation length, in tokens.
    seq_len: int = 463
    # Directory that receives the per-run results file.
    out_dir: str = 'data'
    # Number of passes over the training data.
    epochs: int = 3
    # Multiplier applied to each length bucket's batch factor.
    batchsize: int = 1
    # Seed for the JAX PRNG key used for batch shuffling.
    seed: int = 0
    # Training steps aggregated into each printed log line.
    logging_steps: int = 250
    # Initial value of the cosine-decayed learning rate.
    peak_learning_rate: float = 3e-5
    # AdamW decoupled weight-decay coefficient.
    weight_decay: float = 0.001
def train_peft(cfg):
print(cfg)
model = AutoModelForCausalLM.from_pretrained(cfg.model_name, dtype='auto')
model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters in model : {model_params}")
peft_model = cfg.peft_config.wrap(model)
peft_params = sum(p.numel() for p in peft_model.parameters() if p.requires_grad)
print(f"Total trainable parameters in peft model : {peft_params} and are {(peft_params/model_params)*100} % of the original model")
dataset = get_dataset_gsm8k(cfg)
basename = yaml.dump(draccus.encode(cfg), default_flow_style=False, sort_keys=False)
outf = cfg.out_dir+'/'+hex(abs(hash(basename))).lstrip('0x')+'.out'
print(f'Writing results to {outf}')
outs = open(outf,'w')
outs.write(str({'model_params' : model_params, 'peft_params' : peft_params})+'\n')
# train(peft_model, dataset, OUT_DIR)
return train_jax(peft_model, dataset, cfg, outs)
class ModelWithLoss(torch.nn.Module):
    """Wrap a causal LM so that ``forward`` returns its next-token cross-entropy.

    The loss is the mean over predicted positions, returned as a 1-tuple
    ``(loss,)``. When ``attention_mask`` is given, targets at masked (padding)
    positions are excluded from the mean.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Mean reduction over targets that are not `ignore_index` (-100 by default).
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, *, input_ids=None, attention_mask=None):
        if input_ids is None:
            raise ValueError("input_ids is required")
        result = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Position t predicts token t+1: drop the last logit and the first label.
        logits = result['logits'][..., :-1, :]
        labels = input_ids[:, 1:]
        if attention_mask is not None:
            # Padding must not contribute to the loss. Masked-out targets get the
            # criterion's ignore_index so they are skipped in the mean.
            labels = labels.masked_fill(attention_mask[:, 1:] == 0, self.loss.ignore_index)
        loss = self.loss(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1))
        return (loss,)
def train(model, lm_dataset, output_dir):
    """Train ``model`` with the Hugging Face ``Trainer`` (the PyTorch code path).

    Trains on ``lm_dataset["train"]`` and evaluates on ``lm_dataset["test"]``
    after every epoch, with batch size 1 and no checkpoint saving.

    Returns:
        The ``Trainer`` after ``train()`` has completed.
    """
    # Options previously experimented with: torch_compile=True, bf16=True,
    # eval_on_start=True, and wrapping the model in ModelWithLoss.
    hyperparams = dict(
        output_dir=output_dir,
        eval_strategy="epoch",
        learning_rate=2.5e-6 / 4,
        weight_decay=0.01,
        save_strategy='no',
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        save_safetensors=False,  # work around bug
        gradient_accumulation_steps=1,
        logging_steps=500,
    )
    hf_trainer = Trainer(
        model=model,
        args=TrainingArguments(**hyperparams),
        train_dataset=lm_dataset["train"],
        eval_dataset=lm_dataset["test"],
    )
    hf_trainer.train()
    return hf_trainer
def train_jax(model_torch, lm_dataset, cfg, outs):
    """Fine-tune the trainable parameters of a torch model in JAX with AdamW.

    The torch module is converted with ``t2j``. Its parameters are split into a
    trainable group (``requires_grad``), which is optimized, and a frozen group,
    which is passed through unchanged. Training makes ``cfg.epochs`` passes over
    length-bucketed, shuffled batches of ``lm_dataset['train']`` with a
    cosine-decayed learning rate. It evaluates on ``lm_dataset['test']`` once per
    epoch's worth of steps and at the final step.

    Args:
        model_torch: torch module to train (e.g. the PEFT-wrapped model). Only
            the local reference is deleted after conversion.
        lm_dataset: mapping with 'train' and 'test' splits whose examples carry
            an 'input_ids' token list.
        cfg: ``PeftTrainConfig``; uses seed, batchsize, epochs, logging_steps,
            peak_learning_rate and weight_decay.
        outs: writable text file. One ``str(dict)`` record is written per
            training step (when each log window is flushed) and per evaluation.
            It is closed at the end of this function.

    Returns:
        ``(trainable_state_dict, min_train_loss, min_test_loss)``. Both losses
        are per-token mean cross-entropies, initialized to 100.
        ``min_test_loss`` may be a 0-d JAX array rather than a Python float.

    Training stops early (skipping the remaining steps and evaluations) when a
    log window's per-token loss exceeds 2.0 or is non-finite.
    """
    key = jax.random.key(cfg.seed)
    # keep_vars=True yields the Parameter objects themselves, so requires_grad
    # and data_ptr() are available on each entry.
    state_dict = model_torch.state_dict(keep_vars=True)
    # Group names by storage pointer. Tied/aliased tensors (several names, one
    # buffer) then become a single JAX leaf keyed by the tuple of all their names,
    # so they stay tied during optimization.
    trainable_keys = {}
    nontrainable_keys = {}
    for k,v in state_dict.items():
        assert v.is_contiguous()
        if v.requires_grad:
            assert v.data_ptr() not in nontrainable_keys, "aliased parameters must agree whether they requires_grad"
            trainable_keys.setdefault(v.data_ptr(), []).append(k)
        else:
            assert v.data_ptr() not in trainable_keys, "aliased parameters must agree whether they requires_grad"
            nontrainable_keys.setdefault(v.data_ptr(), []).append(k)
    trainable_state_dict = { tuple(ks) : t2j(state_dict[ks[0]]) for ks in trainable_keys.values() }
    nontrainable_state_dict = { tuple(ks) : t2j(state_dict[ks[0]]) for ks in nontrainable_keys.values() }
    model_jax = t2j(model_torch)
    # Drop local references to the torch objects (the caller's model is unaffected).
    del model_torch, state_dict, trainable_keys, nontrainable_keys
    def model(input_ids, trainable_state_dict, nontrainable_state_dict):
        # Expand the tuple-keyed groups back into a flat {name: array} state dict.
        # All aliased names map to the same array.
        state_dict = dict({k: v for ks, v in trainable_state_dict.items() for k in ks},
                          **{k: v for ks, v in nontrainable_state_dict.items() for k in ks})
        return model_jax(input_ids, state_dict = state_dict)
    def get_batches(key, split='train'):
        """Return shuffled, length-bucketed, zero-padded batches for one pass over `split`.

        Each batch is ``(input_ids, id_mask)`` as numpy arrays of shape
        ``(bucket_factor * cfg.batchsize, seq_len)``. ``seq_len`` is the longest
        example in that bucket, and ``id_mask`` is True on real tokens.
        """
        from itertools import batched  # Python 3.12+
        split = list(lm_dataset[split])
        maxex = max([len(ex['input_ids']) for ex in split])
        # Candidate bucket factors: an example goes to the largest factor b with
        # len <= maxex / b**0.9, so shorter examples are packed more per batch.
        # bs = [1]
        # bs = [1,2]
        bs = [2]
        # bs = [1,2,3,4]
        split_by_b = { }
        for ex in split:
            try:
                b = next(b for b in reversed(bs) if len(ex['input_ids']) <= maxex/b**0.9)
                split_by_b.setdefault(b,[]).append(ex)
            except StopIteration:
                # NOTE(review): an example too long for every bucket is silently
                # dropped. With bs == [2] that is any example longer than
                # maxex / 2**0.9 (about 0.54 * maxex), in both train and test.
                # Confirm this is intended.
                pass
        print(f'{len(split_by_b)} groups of batches by example length')
        # (bucket factor, number of examples, longest example) per bucket.
        print(sorted([(b,len(exs),max(len(ex['input_ids']) for ex in exs)) for b, exs in split_by_b.items()]))
        batches = []
        # `key` is kept for the final cross-bucket shuffle; `keycur` is split
        # into one key per bucket for the within-bucket permutation.
        key, keycur = jax.random.split(key)
        for (b, exs), keycur in zip(sorted(split_by_b.items()),jax.random.split(keycur,len(split_by_b))):
            # seq_len = int(np.floor(maxex / np.sqrt(b)))
            seq_len = max(len(ex['input_ids']) for ex in exs)
            # b = max(b-1,1)
            b *= cfg.batchsize
            ixs = jax.random.permutation(keycur, len(exs))
            for batch in batched(ixs, b):
                batch = [exs[i] for i in batch]
                # A final partial batch keeps the full row count; unused rows
                # stay all-zero with an all-False mask.
                input_ids = np.zeros((b, seq_len), dtype=jnp.int32)
                id_mask = np.zeros((b, seq_len), dtype=jnp.bool)
                for i,ex in enumerate(batch):
                    ids = ex['input_ids']
                    # Right-padding with token id 0.
                    input_ids[i, :len(ids)] = ids
                    id_mask[i, :len(ids)] = True
                batches.append((input_ids, id_mask))
        # Interleave the buckets in random order.
        ixs = jax.random.permutation(key,len(batches))
        return [batches[i] for i in ixs]
    @jax.jit
    def loss_fn(trainable_state_dict,nontrainable_state_dict,input_ids,id_mask):
        """Return (summed next-token cross-entropy over real targets, number of real targets).

        The loss is a sum, not a mean, so gradients scale with the number of
        target tokens in the batch.
        """
        # No attention mask is passed to the model. With right padding, real
        # positions are unaffected by trailing pads only if model_jax attends
        # causally — presumably true for a causal LM; confirm.
        result = model(input_ids, trainable_state_dict, nontrainable_state_dict)
        logits = result.logits[...,:-1,:]
        # Shift by one: the target at position t is token t+1, and it counts
        # only if that token is real.
        id_mask = id_mask[...,1:]
        labels = input_ids[...,1:]
        cross_entropy = optax.losses.softmax_cross_entropy_with_integer_labels(
            logits.reshape(-1, logits.shape[-1]), labels.reshape(-1),
        )
        return jnp.sum(jnp.where(id_mask.reshape(-1), cross_entropy, 0.0)), jnp.sum(id_mask)
    # Buffers of the trainable params and optimizer state are donated. The caller
    # must switch to the returned values, as the training loop below does.
    @functools.partial(jax.jit,donate_argnums=[0,1])
    def update_function(trainable_state_dict, opt_state, nontrainable_state_dict, input_ids, id_mask):
        (loss, tokens), grads = jax.value_and_grad(loss_fn,has_aux=True)(
            trainable_state_dict,nontrainable_state_dict,input_ids,id_mask)
        # Per-leaf squared gradient norms (same tree structure as the params).
        grad_norm_square = jax.tree.map(lambda e: jnp.sum(e**2), grads)
        # grad_norm_square = jnp.sqrt(jax.tree.reduce(operator.add, jax.tree.map(lambda e: jnp.sum(e**2), grads)))
        # `optimizer` is a free variable bound further down in train_jax. It is
        # looked up when this function is first traced, which happens in the
        # training loop after `optimizer` is defined.
        updates, opt_state = optimizer.update(grads, opt_state, trainable_state_dict)
        trainable_state_dict = optax.apply_updates(trainable_state_dict, updates)
        return loss, tokens, grad_norm_square, trainable_state_dict, opt_state
    # The test batches use a fixed key, so every evaluation sees the same batches.
    test_batches = get_batches(jax.random.key(0), split='test')
    def evaluate(trainable_state_dict, nontrainable_state_dict):
        """Return the per-token mean test loss (possibly a 0-d JAX array)."""
        loss = 0.0
        num = 0
        for batch in tqdm(test_batches):
            cur_loss, tokens = loss_fn(trainable_state_dict, nontrainable_state_dict, *batch)
            loss += float(cur_loss)
            num += tokens
        return loss / num
    # eval_loss = evaluate(trainable_state_dict, nontrainable_state_dict)
    # print(f'eval_loss = {float(eval_loss)}')
    # One independently shuffled pass per epoch, concatenated into a single step list.
    batches = [batch for key in jax.random.split(key, cfg.epochs) for batch in get_batches(key)]
    epoch_its = len(batches) // cfg.epochs
    # schedule = optax.schedules.constant_schedule(cfg.peak_learning_rate)
    # schedule = optax.schedules.warmup_cosine_decay_schedule(cfg.peak_learning_rate/10, cfg.peak_learning_rate,
    # warmup_steps = len(batches) // 20, decay_steps=len(batches))
    # Cosine decay from the peak learning rate over all training steps.
    schedule = optax.schedules.cosine_decay_schedule(cfg.peak_learning_rate, decay_steps=len(batches))
    # optimizer = optax.adam(schedule)
    # optimizer = optax.adamw(schedule)
    optimizer = optax.adamw(schedule,weight_decay=cfg.weight_decay)
    opt_state = optimizer.init(trainable_state_dict)
    min_train_loss = 100.
    min_test_loss = 100.
    # Per-step device values for the current log window. They are converted to
    # Python numbers only when the window is flushed, so training steps do not
    # force a host sync.
    losses = []
    grad_norm_squares = []
    tokenss = []
    for it,batch in tqdm(enumerate(batches),total=len(batches)):
        loss, tokens, grad_norm_square, trainable_state_dict, opt_state = update_function(
            trainable_state_dict, opt_state, nontrainable_state_dict, *batch
        )
        losses.append(loss)
        grad_norm_squares.append(grad_norm_square)
        tokenss.append(tokens)
        # Flush the log window every `logging_steps` steps and at the last step.
        if it % cfg.logging_steps == cfg.logging_steps-1 or it == len(batches)-1:
            cum_loss = 0.0
            cum_grad_norm_square = None
            ntokens = 0
            for loss, grad_norm_square, tokens in zip(losses, grad_norm_squares, tokenss):
                loss = float(loss)
                grad_norm_square = jax.tree.map(lambda e: float(e), grad_norm_square)
                tokens = int(tokens)
                # One record per step: summed loss, total squared grad norm, target tokens.
                outs.write(str({'loss' : loss, 'grad_norm_square' : sum(grad_norm_square.values()), 'tokens' : tokens})+'\n')
                cum_loss += loss
                cum_grad_norm_square = grad_norm_square if cum_grad_norm_square is None else \
                    jax.tree.map(operator.add, cum_grad_norm_square, grad_norm_square)
                ntokens += tokens
            # 'loss' is the per-token mean over the window. 'grad_norm' is
            # sqrt(sum of squared grad norms over the window's steps / ntokens).
            print({'loss' : cum_loss / ntokens,
                   'learning_rate' : float(schedule(it)),
                   'grad_norm' : float(np.sqrt(jax.tree.reduce(operator.add,cum_grad_norm_square) / ntokens)),
                   'epoch' : it / epoch_its })
            # for k,v in jax.tree.map(lambda e: np.sqrt(e), cum_grad_norm_square).items():
            # print(k,v)
            # Divergence guard: abort training on a high or non-finite window loss.
            if cum_loss / ntokens > 2.0 or not jnp.isfinite(cum_loss):
                break
            min_train_loss = min(min_train_loss, cum_loss / ntokens)
            losses = []
            grad_norm_squares = []
            tokenss = []
        # Evaluate at the end of each epoch's worth of steps and at the final step.
        if it % epoch_its == epoch_its-1 or it == len(batches)-1:
            eval_loss = evaluate(trainable_state_dict, nontrainable_state_dict)
            min_test_loss = min(min_test_loss, eval_loss)
            print(f'eval_loss = {float(eval_loss)}')
            outs.write(str({'eval_loss' : float(eval_loss)})+'\n')
    outs.close()
    return trainable_state_dict, min_train_loss, min_test_loss
def get_dataset_gsm8k(cfg):
data = datasets.load_dataset('openai/gsm8k', 'main')
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
if cfg.dataset_randomize:
data = datasets.concatenate_datasets([data['train'],data['test']])
data = data.train_test_split(test_size=0.15)
def preprocess(batch):
def format_question(ex):
return f"Q: {ex}\nA: "
def format_answer(ex):
# this is what helm does
answer_text = ex.replace("####", "The answer is").replace("\n", " ") + "."
return f"{answer_text}\n{tokenizer.eos_token}"
sources = [format_question(question) for question in batch['question']]
targets = [format_answer(answer) for answer in batch['answer']]
examples = [s + t for s, t in zip(sources, targets)]
sources_tokenized = tokenizer(sources, return_tensors="np", padding=False, truncation=True, max_length=cfg.seq_len)
examples_tokenized = tokenizer(examples, return_tensors="np", padding=False, truncation=True, max_length=cfg.seq_len)
# examples_tokenized = tokenizer(examples, return_tensors="np", padding='max_length', truncation=True, max_length=cfg.seq_len)
source_lens = [len(s) for s in sources_tokenized["input_ids"]]
return {
"input_ids": examples_tokenized["input_ids"],
"labels": examples_tokenized["input_ids"],
"source_lens": source_lens,
}
return data.map(preprocess, batched=True)
if __name__ == '__main__':
    # draccus builds the PeftTrainConfig from command-line arguments.
    cfg = draccus.parse(config_class=PeftTrainConfig)
    # train_peft returns (trainable_state_dict, min_train_loss, min_test_loss).
    results = train_peft(cfg)