import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm

from model import CharacterLevelTokenizer, Config, PotterGPT

torch.manual_seed(1357)

with open('data/harry_potter_data', 'r', encoding='utf-8') as f:
    data = f.read()
class Dataset:
    def __init__(self, config, is_test=False) -> None:
        self.tokenizer = CharacterLevelTokenizer(data)
        self.is_test = is_test
        self.full_data = self.tokenizer.encode(self.tokenizer.data)
        # 90/10 train/validation split over the encoded token stream
        if self.is_test:
            self.data = self.full_data[int(0.9 * len(self.full_data)):]
        else:
            self.data = self.full_data[:int(0.9 * len(self.full_data))]
        self.block_size = config.block_size
        self.batch_size = config.batch_size

    def __len__(self) -> int:
        return len(self.data)

    def get_block_size(self) -> int:
        return self.block_size

    def get_vocab_size(self) -> int:
        return self.tokenizer.VOCAB_SIZE

    def get(self):
        # sample batch_size random windows; targets are the inputs shifted one token right
        ix = torch.randint(len(self.data) - self.block_size, (self.batch_size,))
        x = torch.stack([self.data[i:i + self.block_size] for i in ix])
        y = torch.stack([self.data[i + 1:i + self.block_size + 1] for i in ix])
        return x, y
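
# Usage sketch (kept commented out so it doesn't consume RNG draws before training):
# ds = Dataset(Config)
# xb, yb = ds.get()  # both (Config.batch_size, Config.block_size); yb is xb shifted by one token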
# Alternative: a pretrained subword tokenizer (requires the `tokenizers` package):
# from tokenizers import Tokenizer
# tokenizer = Tokenizer.from_file('tokenizer/potter.json')
tokenizer = CharacterLevelTokenizer(data)

# Training setup
train_ds = Dataset(Config)
val_ds = Dataset(Config, is_test=True)

lm = PotterGPT(Config)
lm = lm.to(device=Config.device)

optim = torch.optim.Adam(lm.parameters(), lr=Config.lr)
def loss_fn(logits, targets):
    # flatten (B, T, C) logits and (B, T) targets so F.cross_entropy
    # scores every position in the batch independently
    B, T, C = logits.shape
    logits = logits.view(B * T, C)
    targets = targets.view(B * T)
    loss = F.cross_entropy(logits, targets)
    return loss
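
# Worked shapes (illustrative numbers, not the actual Config values): with B=2, T=4, C=65,
# logits (2, 4, 65) -> (8, 65) and targets (2, 4) -> (8,), so cross_entropy compares each
# of the 8 positions against its next-token id.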
def train_N_iters():
    lm.train()
    train_step_losses = []
    for batch in tqdm(range(Config.train_iters)):
        optim.zero_grad()
        inputs, targets = train_ds.get()
        inputs, targets = inputs.to(device=Config.device), targets.to(device=Config.device)
        logits = lm(inputs)
        loss = loss_fn(logits, targets)
        loss.backward()
        optim.step()
        train_step_losses.append(loss.item())
        # log roughly every 10% of the run, plus the final step
        if batch % (Config.train_iters // 10) == 0 or batch == Config.train_iters - 1:
            print(f"batch {batch} train step loss: {loss.item()}")
        del inputs, targets, loss, logits
    return train_step_losses
@torch.no_grad()  # no gradients needed for evaluation
def valid_N_iters():
    lm.eval()
    val_step_losses = []
    for batch in tqdm(range(Config.val_iters)):
        inputs, targets = val_ds.get()
        inputs, targets = inputs.to(device=Config.device), targets.to(device=Config.device)
        logits = lm(inputs)
        loss = loss_fn(logits, targets)
        val_step_losses.append(loss.item())
        if batch % (Config.val_iters // 10) == 0 or batch == Config.val_iters - 1:
            print(f"batch {batch} valid step loss: {loss.item()}")
        del inputs, targets, loss, logits
    return val_step_losses
def save_lm():
    # checkpoint the weights into ./potterGPT next to the training script
    state_dict = lm.state_dict()
    save_path = Path('./').resolve() / 'potterGPT'
    save_path.mkdir(exist_ok=True)
    model_path = save_path / 'potterGPT.pth'
    torch.save(state_dict, model_path)
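
# Reload sketch (assumes the same Config the checkpoint was trained with):
# lm = PotterGPT(Config)
# lm.load_state_dict(torch.load('potterGPT/potterGPT.pth', map_location=Config.device))
# lm.eval()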
def train_lm():
    train_losses = train_N_iters()
    valid_losses = valid_N_iters()
    save_lm()
    return train_losses, valid_losses

tl, vl = train_lm()
plt.plot(tl, label='train loss', color='orange')
plt.plot(vl, label='valid loss', color='blue')
plt.title('Potter GPT Losses')
plt.legend()
plt.show()
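
# To keep a copy of the curve (sketch; the filename is an assumption),
# call plt.savefig('potterGPT/losses.png') before plt.show().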
# Sample from the trained model at several lengths, seeding with a single 0 token
generated_texts = []
for length in [100, 300, 500, 700, 1000]:
    generated = lm.generate(
        torch.zeros((1, 1), dtype=torch.long, device=Config.device),  # initial context: token id 0
        total=length
    )
    generated = tokenizer.decode(generated[0])
    text = f'generated ({length} tokens)\n{"="*50}\n{generated}\n{"="*50}\n\n'
    generated_texts.append(text)
    print(text)
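
# Optional sketch: persist the samples next to the checkpoint (filename is an assumption):
# (Path('./').resolve() / 'potterGPT' / 'samples.txt').write_text(''.join(generated_texts), encoding='utf-8')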