-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
116 lines (94 loc) · 4.5 KB
/
main.py
File metadata and controls
116 lines (94 loc) · 4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import glob
import os

import numpy as np
import torch
from fastai.data.external import URLs, untar_data
from torch import nn, optim

from data.faces import download_faces_data
from recolor.data_loaders import make_dataloaders
from recolor.models import MainModel, build_res_unet
from recolor.train import pretrain_generator, train_model
from recolor.utils import exists, str2bool
def check_opts(opts):
    """Validate parsed command-line options before any training work starts.

    Args:
        opts: ``argparse.Namespace`` produced by ``build_parser()``.

    Raises:
        ValueError: if ``epochs`` is not a positive integer.
    """
    # NOTE(review): the data-path check is deliberately disabled; --data-path
    # is declared optional in build_parser(), so leaving it commented out.
    # exists(opts.data_path, "Data path not found!")
    exists(opts.save_path, "Save path not found!")
    # `assert` is stripped under `python -O`, so validate input explicitly.
    if opts.epochs <= 0:
        raise ValueError("Epochs must be higher than 0")
def validate_data_type(dtype):
    """argparse ``type`` hook: normalise the training type to lower case.

    Accepts "face" or "general" in any letter case; any other value is
    rejected with an ``ArgumentTypeError`` so argparse reports a usage error.
    """
    normalized = dtype.lower()
    if normalized not in ("face", "general"):
        raise argparse.ArgumentTypeError("Invalid train type, supported are `face` and `general`")
    return normalized
def build_parser():
    """Build the command-line argument parser for the training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train-type",
        type=validate_data_type,
        dest="type",
        help="What type of model to train",
        metavar="TRAIN_TYPE",
        required=True,
    )
    parser.add_argument(
        "--data-path",
        type=str,
        dest="data_path",
        help="Path to data",
        metavar="IN_PATH",
        required=False,
    )
    parser.add_argument(
        "--save-path",
        type=str,
        dest="save_path",
        help="Path to save the model",
        metavar="SAVE_PATH",
        required=True,
    )
    parser.add_argument(
        "--pretrain",
        type=str2bool,
        dest="pretrain",
        help="If to pretrain the GAN before training the main model",
        metavar="PRETRAIN",
        required=True,
    )
    parser.add_argument(
        "--epochs",
        type=int,
        dest="epochs",
        help="Number epochs to train the model for",
        metavar="EPOCHS",
        required=True,
    )
    parser.add_argument(
        "--use-gpu",
        type=str2bool,
        dest="use_gpu",
        help="Use a GPU if available",
        metavar="USE_GPU",
        required=True,
    )
    return parser
if __name__ == "__main__":
    parser = build_parser()
    options = parser.parse_args()
    check_opts(options)

    # Pick the compute device; an explicit --use-gpu false wins even when
    # CUDA is available. Always keep a torch.device (the original fell back
    # to the bare string "cpu", which works but is inconsistent).
    device = torch.device(
        "cuda" if torch.cuda.is_available() and options.use_gpu else "cpu"
    )

    paths = []
    if options.type == "face":
        download_faces_data()
        # BUG FIX: glob.glob does NOT expand "~", so the original patterns
        # silently matched nothing — expanduser is required.
        base = os.path.expanduser("~/datasets/celebahq/celeba_hq")
        for split in ("train", "val"):
            for gender in ("female", "male"):
                paths.extend(glob.glob(f"{base}/{split}/{gender}/*.jpg"))
        print("Celebrity face dataset loaded!")
    else:
        coco_path = untar_data(URLs.COCO_SAMPLE)
        coco_path = str(coco_path) + "/train_sample"
        paths.extend(glob.glob(coco_path + "/*.jpg"))
        print("COCO dataset loaded!")

    np.random.seed(123)
    # TODO: expose the subset size as a CLI argument instead of hard-coding it.
    # Clamp to the number of available images: np.random.choice raises
    # ValueError when asked for more samples than exist with replace=False.
    subset_size = min(100, len(paths))
    paths_subset = np.random.choice(paths, subset_size, replace=False)
    rand_idxs = np.random.permutation(subset_size)
    split_at = int(subset_size * 0.8)  # 80/20 train/validation split
    train_idxs = rand_idxs[:split_at]
    val_idxs = rand_idxs[split_at:]
    train_paths = paths_subset[train_idxs]
    val_paths = paths_subset[val_idxs]
    print(len(train_paths), len(val_paths))

    train_dl = make_dataloaders(paths=train_paths, split='train')
    val_dl = make_dataloaders(paths=val_paths, split='val')
    # Sanity-check one batch; 'L' is the lightness channel, 'ab' the two
    # colour channels of the Lab representation the loaders produce.
    data = next(iter(train_dl))
    Ls, abs_ = data['L'], data['ab']
    print(Ls.shape, abs_.shape)
    print(len(train_dl), len(val_dl))

    if options.pretrain is False:
        # Train the full GAN from scratch with a randomly-initialised generator.
        model = MainModel(device=device)
        train_model(model, train_dl, val_dl, options.epochs)
    else:
        # Pretrain the ResNet18-U-Net generator with a plain L1 loss, save the
        # weights, then reload them into a fresh generator for GAN training.
        net_G = build_res_unet(n_input=1, n_output=2, size=256, device=device)
        opt = optim.Adam(net_G.parameters(), lr=1e-4)
        criterion = nn.L1Loss()
        pretrain_generator(net_G, train_dl, opt, criterion, options.epochs, device=device)
        torch.save(net_G.state_dict(), f"{options.save_path}/res18-unet.pt")
        net_G = build_res_unet(n_input=1, n_output=2, size=256)
        net_G.load_state_dict(torch.load(f"{options.save_path}/res18-unet.pt", map_location=device))
        model = MainModel(net_G=net_G, device=device)
        train_model(model, train_dl, val_dl, epochs=options.epochs)
    torch.save(model.state_dict(), f"{options.save_path}/final_model_weights.pt")