-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
237 lines (206 loc) · 8.6 KB
/
train.py
File metadata and controls
237 lines (206 loc) · 8.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import os
import onnx.checker
from tqdm import tqdm
import logging
import itertools
import onnx
import onnxruntime
from IPython import embed
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from models import Policy
from util import SCRAMBLE_TYPE_TO_STATE_FUNC,get_parameter_number
from dataset import get_dataset
import loss
def get_config(cfg, args, key):
    """Resolve a setting, preferring CLI arguments over the config dict.

    Args:
        cfg: mapping of configuration values (e.g. a parsed config file).
        args: namespace of command-line arguments (e.g. argparse.Namespace).
        key: name of the setting to look up.

    Returns:
        The first truthy value found on ``args``; otherwise ``cfg[key]`` if
        present; otherwise ``None`` (a warning is logged in that case).
    """
    # A falsy CLI value (None, "", 0, False) deliberately falls through to
    # cfg, so a flag only overrides the config when it was explicitly set.
    if hasattr(args, key) and (attr := getattr(args, key)):
        return attr
    if key in cfg:
        return cfg[key]
    logging.warning(f'There is no setting for `{key}`')
    return None
class Workspace:
    """Training / inference workspace for the reward (Policy) network.

    Wires dataset, dataloader, network, optimizer and loss together from a
    config dict plus CLI args, and exposes the ``train`` / ``infer`` /
    ``to_onnx`` / ``save_model`` / ``load_model`` entry points.
    """

    def __init__(self, cfg, args):
        self.cfg = cfg
        self.args = args
        # 1. use dataset to init dataloader
        self.cube_type = self.get_config("cube_type")
        dataset = get_dataset(self)
        # TODO: support no data
        if not dataset:
            logging.error(f"There is no dataset for {self.get_config('dataset')}")
        else:
            self.dataloader = DataLoader(
                dataset, shuffle=True,
                num_workers=self.get_config("dataloader_workers"),
                batch_size=self.get_config("batch_size")
            )
            # make it infinite
            self.dataloader = itertools.cycle(self.dataloader)
        # peek one batch to discover the flattened state dimension
        x, _ = next(self.dataloader)
        input_size = x.shape[-1]
        logging.info(f"use x shape: {x.shape}[-1] -> {input_size} as input size")
        # 1.5 output dir
        self.exp_dir = self.get_config("exp_dir")
        if not os.path.exists(self.exp_dir):
            # exist_ok=True guards against a concurrent run creating the
            # directory between the exists() check and makedirs()
            os.makedirs(self.exp_dir, exist_ok=True)
            logging.info(f"create {self.exp_dir}")
        # 2. init hyper params & network
        # NOTE(review): `device` is kept as a module-level global because
        # infer() and to_onnx() read it after __init__ has run.
        global device
        device = self.get_config("device")
        self.epoch = self.get_config("epoch")
        self.batch_size = self.get_config("batch_size")
        self.lr = self.get_config("lr")
        self.network = Policy(
            input_size,
            1,  # scalar reward head
            self.get_config("hidden_layer_size"),
            self.get_config("hidden_depth"),
        )
        self.network.to(device)
        logging.info(f"model params: {get_parameter_number(self.network)}")
        self.optimizer = optim.Adam(list(self.network.parameters()), lr=self.lr)
        loss_fn = self.get_config("loss")
        self.loss_margin = self.get_config("loss_margin")
        if not hasattr(loss, loss_fn):
            logging.error(f"There is not impl for loss function: {loss_fn}")
            # raise SystemExit instead of the interactive-only builtin exit();
            # same exit status, works without the `site` module.
            raise SystemExit(-2)
        self.loss = getattr(loss, loss_fn)()
        if self.get_config("pretrain_model"):
            logging.info("trying to load pretrain model")
            self.load_model()
        # book keeping
        self.plot_window = self.get_config("plot_window")
        self.train_steps = []  # x-axis for the loss plot
        self.losses = []       # one entry per training step

    # region-----------helper function---------------
    def init_and_get_loss_margin(self):
        """Lazily build (once) and return the loss margin tensor on `device`."""
        if not hasattr(self, "loss_margin_ts"):
            margin_ts = torch.tensor(self.loss_margin, dtype=torch.float32).to(device)
            self.loss_margin_ts = margin_ts
        return self.loss_margin_ts

    def get_config(self, key):
        """Resolve `key` from CLI args first, then the config dict."""
        return get_config(self.cfg, self.args, key)
    # endregion------------------

    # region-------------interface--------------
    def train(self):
        """Run the pairwise (chosen vs. rejected) training loop.

        Periodically saves a loss curve to `{exp_dir}/train_loss.png` and
        checkpoints the final weights via save_model().
        """
        logging.info("begin to train...")
        self.network.train()
        with tqdm(range(self.epoch)) as tepoch:
            for i in tepoch:
                self.optimizer.zero_grad()
                chosen_states_ts, reject_states_ts = next(self.dataloader)
                chosen_states_ts = chosen_states_ts.to(device)
                reject_states_ts = reject_states_ts.to(device)
                # infer twice: one forward pass per side of the pair
                chosen_rewards = self.network(chosen_states_ts)
                reject_rewards = self.network(reject_states_ts)
                # calc loss
                epoch_loss = self.loss(
                    chosen_rewards,
                    reject_rewards,
                    self.init_and_get_loss_margin()
                )
                # step
                epoch_loss.backward()
                self.optimizer.step()
                # plot
                self.train_steps.append(i)
                self.losses.append(epoch_loss.item())
                if i % self.plot_window == 0:
                    plt.clf(); plt.cla()
                    plt.plot(self.train_steps, self.losses, label="loss", color="blue")
                    plt.legend()
                    plt.savefig(f"{self.exp_dir}/train_loss.png")
                tepoch.set_postfix(loss=f"{epoch_loss.item():.4f}")
        self.save_model()

    def infer(self):
        """Interactive loop: read scrambles from stdin and print their reward."""
        if len(self.losses) == 0:
            logging.warning("infer model may not have been trained")
        self.network.eval()
        print("Enter infer loop, type `quit` or `exit` to quit the loop")
        while True:
            scramble = input("input scramble >").strip()
            if scramble in {"quit", "exit"}:
                print("quit infer loop")
                break
            try:
                status = SCRAMBLE_TYPE_TO_STATE_FUNC[self.cube_type](scramble)
                status = np.array(status)
                status = torch.tensor(status, dtype=torch.float32).to(device)
                reward = self.network(status)
                print(reward)
            except Exception as e:
                # malformed scrambles are expected user input: report and
                # keep looping. (The old `finally: locals().get("e")` block
                # was dead code -- Python 3 unbinds `e` after the except.)
                print(e)

    # trans to onnx model
    def to_onnx(self):
        """Export the network to ONNX, then validate it with onnxruntime.

        Uses the empty-scramble (solved) state as the sample input and
        checks the ONNX output matches the torch output numerically.
        """
        x = SCRAMBLE_TYPE_TO_STATE_FUNC[self.cube_type]("")  # solved state
        x = np.array(x)
        x = torch.tensor(x, dtype=torch.float32)
        x = x.unsqueeze_(0)  # (batch_idx, state_dim)
        x_np = x.numpy()     # keep a CPU copy for the onnxruntime check
        x = x.to(device)
        y = self.network(x)  # (batch_idx, 1)
        logging.info(f"to onnx, input dim:{x.shape}, output dim: {y.shape}")
        logging.info(f"model param nums: {get_parameter_number(self.network)}")
        onnx_path = f"{self.exp_dir}/exp_onnx_{len(self.losses)}.onnx"
        torch.onnx.export(
            self.network,
            x,
            onnx_path,
            export_params=True,
            opset_version=11,
            do_constant_folding=True,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={
                # allow any batch size at inference time
                "input": {0: "state_num"},
            }
        )
        # onnx load test
        onnx_model = onnx.load(onnx_path)
        onnx.checker.check_model(onnx_model)
        ort_session = onnxruntime.InferenceSession(onnx_path)
        ort_inputs = {ort_session.get_inputs()[0].name: x_np}
        logging.info(f"onnx test input: {ort_inputs}")
        import time
        beg_time = time.time()
        # rough throughput benchmark of the exported model
        for _ in range(10000):
            ort_outs = ort_session.run(None, ort_inputs)
        logging.info(f"onnx test infer time: {time.time() - beg_time}")
        np.testing.assert_allclose(
            y.detach().cpu().numpy(), ort_outs[0], rtol=1e-3, atol=1e-5
        )

    def save_model(self):
        """Checkpoint current weights, tagged with the trained step count."""
        torch.save(self.network.state_dict(), f"{self.exp_dir}/model_{len(self.losses)}.bin")

    def load_model(self):
        """Load pretrain weights into the network.

        Tries, in order: the raw path, `{data_dir}/path`, `{exp_dir}/path`.
        Logs a warning (and leaves the network untouched) if none exists.
        """
        model_path = self.get_config("pretrain_model")
        if not model_path:
            logging.warning("load model with not pretrain_model setting")
            return
        data_dir = self.get_config("data_dir")
        exp_dir = self.get_config("exp_dir")
        candidates = (model_path, f"{data_dir}/{model_path}", f"{exp_dir}/{model_path}")
        for path in candidates:
            if os.path.isfile(path):
                logging.info(f"load model from {path}")
                # map_location lets a GPU-trained checkpoint load on CPU
                self.network.load_state_dict(torch.load(path, map_location=device))
                return
        logging.warning(f"invalid model_path: {model_path}, check")
    # endregion------------------