-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain-model.py
More file actions
95 lines (74 loc) · 2.67 KB
/
train-model.py
File metadata and controls
95 lines (74 loc) · 2.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
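# NOTE: `src.resolvers` (imported below) presumably contains the custom Hydra resolvers; keep that import even though the name is never referenced directly.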
import hydra
import torch
import typing
import logging
import omegaconf
import pytorch_lightning
import pytorch_lightning.loggers
import src.resolvers
from src.auth import login_to_huggingface
from src.config import read_config, save_config
from src.models import MoLiNER, StartEndSegmentationModel
from src.data.utils.collator import SimpleBatchStructureCollator
from src.constants import (
DEFAULT_HYDRA_CONFIG_PATH,
DEFAULT_HYDRA_VERSION_BASE
)
# --- --- --- --- --- --- ---
# Executed at import time, i.e. before `train_model` is invoked and before
# any dataset/model is instantiated.
# NOTE(review): presumably authenticates against the Hugging Face Hub so that
# later downloads succeed - confirm in `src.auth`.
login_to_huggingface()
# --- --- --- --- --- --- ---
# Module-level logger used by `train_model` for its progress messages.
logger = logging.getLogger(__name__)
# type: ignore
@hydra.main(
    config_path=DEFAULT_HYDRA_CONFIG_PATH,
    config_name="train-model",
    version_base=DEFAULT_HYDRA_VERSION_BASE,
)
def train_model(cfg: omegaconf.DictConfig) -> None:
    """Train (or resume training of) the model described by ``cfg``.

    Steps, in order:

    1. Resolve the effective config: when ``cfg.resume_dir`` is set the config
       is re-read from that directory via ``read_config`` and ``cfg.ckpt`` is
       used as the checkpoint to resume from; otherwise the current config is
       persisted with ``save_config``.
    2. Seed everything with ``cfg.seed``.
    3. Instantiate the train/validation datasets (``cfg.data``), the model
       (``cfg.model``), both dataloaders (``cfg.dataloader``) and the trainer
       (``cfg.trainer``) through ``hydra.utils.instantiate``.
    4. Call ``trainer.fit`` with the optional checkpoint path.

    Args:
        cfg: Config supplied by the ``@hydra.main`` decorator. The keys read
            here are ``run_dir``, ``resume_dir``, ``ckpt``, ``seed``, ``data``,
            ``model``, ``dataloader`` and ``trainer``.

    Raises:
        ValueError: If ``cfg.resume_dir`` is set but ``cfg.ckpt`` is ``None``.
    """
    logger.debug(f"[cfg]: {cfg}")
    logger.info(f"[run_dir]: {cfg.run_dir}")

    ckpt = None
    if cfg.resume_dir is not None:
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would silently restart training from scratch.
        if cfg.ckpt is None:
            raise ValueError("`ckpt` must be set when `resume_dir` is provided")

        # Capture both values BEFORE `cfg` is replaced by the stored config;
        # the stored config carries its own (different) `resume_dir`/`ckpt`.
        resume_dir = cfg.resume_dir
        ckpt = cfg.ckpt
        cfg = read_config(resume_dir)

        logger.info("Resuming training")
        logger.info(f"The config is loaded from: \n{resume_dir}")
    else:
        config_path = save_config(cfg)
        logger.info("Training script")
        logger.info(f"The config can be found here: \n{config_path}")

    logger.info(f"[ckpt]: {ckpt}")

    pytorch_lightning.seed_everything(cfg.seed)

    logger.info("[data]: loading the dataloaders")
    train_dataset = hydra.utils.instantiate(
        cfg.data,
        split="train",
    )
    validation_dataset = hydra.utils.instantiate(
        cfg.data,
        split="validation",
    )

    logger.info("[model]: loading the model")
    model: MoLiNER | StartEndSegmentationModel = hydra.utils.instantiate(cfg.model)

    # Only MoLiNER exposes a prompts tokens encoder; every other model gets a
    # collator built with `None`.
    prompts_tokens_encoder = (
        model.prompts_tokens_encoder if isinstance(model, MoLiNER) else None
    )

    # Each dataloader gets its own collator instance, as before.
    train_dataloader: torch.utils.data.DataLoader = hydra.utils.instantiate(
        cfg.dataloader,
        dataset=train_dataset,
        collate_fn=SimpleBatchStructureCollator(prompts_tokens_encoder),
        shuffle=True,
    )
    validation_dataloader: torch.utils.data.DataLoader = hydra.utils.instantiate(
        cfg.dataloader,
        dataset=validation_dataset,
        collate_fn=SimpleBatchStructureCollator(prompts_tokens_encoder),
        shuffle=False,
    )

    trainer = hydra.utils.instantiate(cfg.trainer)

    # NOTE(review): no weight-loading step is visible in this function; this
    # message may be stale - confirm whether it should be removed.
    logger.info("[model]: loading motion encoder weights")

    logger.info("[training]: started")
    trainer.fit(
        model,
        train_dataloader,
        validation_dataloader,
        ckpt_path=ckpt,
    )
    logger.info("[training]: completed")
# Script entry point: `train_model` is called without arguments because the
# `@hydra.main` decorator on it supplies `cfg`.
if __name__ == "__main__":
    train_model()