-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
120 lines (112 loc) · 4.09 KB
/
train.py
File metadata and controls
120 lines (112 loc) · 4.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import argparse
import torch
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy
from datetime import timedelta
from setproctitle import setproctitle
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor
from model import DropNet
from datamodule import DropNetDataModule
def main(hparams):
    """Train the DropNet NCF model with PyTorch Lightning, logging to W&B.

    Args:
        hparams: parsed ``argparse.Namespace``; also carries the Lightning
            Trainer flags added by ``pl.Trainer.add_argparse_args``. Extra
            attributes (``logger``, ``callbacks``, ``strategy``) are attached
            here before the namespace is handed to ``Trainer.from_argparse_args``.
    """
    wandb_logger = WandbLogger(project="test", name="NCF_DropNet", save_dir="./")
    setproctitle("")
    pl.seed_everything(hparams.seed)

    ncf_datamodule = DropNetDataModule(hparams)
    model = DropNet(hparams)

    # TODO: per the original note, the warm-start checkpoint path for the
    # pretrained embeddings is intentionally hard-coded here.
    # Bug fix: the original unconditionally called torch.load("") which raises
    # at runtime — only load when a path is actually configured.
    pretrained_ckpt_path = ""
    if pretrained_ckpt_path:
        model.load_state_dict(
            torch.load(pretrained_ckpt_path, map_location="cuda"),
            strict=False,  # allow partial warm-start (embeddings only)
        )

    wandb_logger.watch(model, log="all")
    hparams.logger = wandb_logger

    checkpoint_callback = ModelCheckpoint(
        dirpath=hparams.output_dir,
        save_top_k=3,
        mode="max",            # keep the 3 checkpoints with the highest "score"
        monitor="score",
        filename="test-{epoch:02d}-{val_loss:.4f}",
    )
    lr_monitor = LearningRateMonitor(logging_interval="step")
    hparams.callbacks = [checkpoint_callback, lr_monitor]

    # Very long DDP timeout so one slow rank (e.g. during data preparation)
    # does not abort the whole job.
    hparams.strategy = DDPStrategy(timeout=timedelta(days=30))

    trainer = pl.Trainer.from_argparse_args(hparams)
    trainer.fit(model, datamodule=ncf_datamodule)

    # Bug fix: the original evaluated best_model_path and discarded the value;
    # report it so the best checkpoint can be found after training.
    print(f"best model checkpoint: {checkpoint_callback.best_model_path}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Inherit every pl.Trainer flag (max_epochs, gpus, precision, ...).
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument("--seed", default=None, type=int, help="all seed")
    parser.add_argument("--local_rank", type=int, help="ddp local rank")
    parser.add_argument("--data_dir", type=str, help="target pytorch lightning data dirs")
    parser.add_argument("--ratio", type=float, help="train/valid split ratio")
    parser.add_argument("--output_dir", type=str, help="model output path")
    parser.add_argument("--num_proc", type=int, default=None, help="how many proc map?")
    parser.add_argument("--learning_rate", default=0.001, type=float, help="learning rate")
    parser.add_argument(
        "--warmup_ratio", default=0.2, type=float, help="learning rate scheduler warmup ratio per EPOCH"
    )
    parser.add_argument("--max_lr", default=0.01, type=float, help="lr_scheduler max learning rate")
    # Bug fix: default 1e4 is a float, so type must be float (type=int rejects
    # "1e4"-style CLI values and mismatches the default). Help corrected: for
    # OneCycleLR the final lr is initial_lr / final_div_factor, i.e. (max_lr/25)/final_div_factor.
    parser.add_argument(
        "--final_div_factor", default=1e4, type=float, help="final lr is (max_lr/25)/final_div_factor"
    )
    parser.add_argument("--weight_decay", default=0.0001, type=float, help="weight decay")
    parser.add_argument(
        "--per_device_train_batch_size",
        default=1,
        type=int,
        help="The batch size per GPU/TPU core/CPU for training.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        default=1,
        type=int,
        help="The batch size per GPU/TPU core/CPU for evaluation.",
    )
    # Bug fix: the help strings below were all copy-pasted from the eval batch
    # size flag; each flag now documents itself.
    parser.add_argument(
        "--n_users",
        default=200,
        type=int,
        help="number of users (size of the user embedding table)",
    )
    parser.add_argument(
        "--n_items",
        default=30000,
        type=int,
        help="number of items (size of the item embedding table)",
    )
    parser.add_argument(
        "--dropout",
        default=0.05,
        type=float,
        help="dropout probability",
    )
    parser.add_argument(
        "--emb_dim",
        default=256,
        type=int,
        help="embedding dimension",
    )
    parser.add_argument(
        "--layer_dim",
        default=256,
        type=int,
        help="hidden layer dimension",
    )
    parser.add_argument(
        "--n_items_features",
        default=300,
        type=int,
        help="number of item feature columns",
    )
    parser.add_argument(
        "--n_users_features",
        default=200,
        type=int,
        help="number of user feature columns",
    )
    args = parser.parse_args()
    main(args)