-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathrun.py
More file actions
executable file
·105 lines (89 loc) · 4.59 KB
/
Copy pathrun.py
File metadata and controls
executable file
·105 lines (89 loc) · 4.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import argparse
import os
import warnings
import torch
import torch.multiprocessing as mp
from core.logger import VisualWriter, InfoLogger
import core.praser as Praser
import core.util as Util
from data import define_dataloader
from models import create_model, define_network, define_loss, define_metric
def main_worker(gpu, ngpus_per_node, opt):
    """Build the model for one worker process and run the configured phase.

    Invoked as ``main_worker(gpu, ngpus_per_node, opt)`` by
    ``torch.multiprocessing.spawn`` (one process per GPU) when
    ``opt['distributed']`` is true, and directly as ``main_worker(0, 1, opt)``
    otherwise.

    Args:
        gpu: Index of this worker. Used as both ``opt['local_rank']`` and
            ``opt['global_rank']`` when ``opt`` has no ``'local_rank'`` yet.
        ngpus_per_node: Number of workers on this node. Not read in this
            body; kept to match the ``mp.spawn`` argument tuple.
        opt: Parsed option dict. Mutated in place (rank keys may be added).
            Keys read here include ``'distributed'``, ``'seed'``, ``'path'``,
            ``'model'`` and ``'phase'`` (plus ``'init_method'``,
            ``'world_size'`` and ``'global_rank'`` in distributed mode).
    """
    if 'local_rank' not in opt:
        opt['local_rank'] = opt['global_rank'] = gpu
    if opt['distributed']:
        torch.cuda.set_device(int(opt['local_rank']))
        print('using GPU {} for training'.format(int(opt['local_rank'])))
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=opt['init_method'],
            world_size=opt['world_size'],
            rank=opt['global_rank'],
            group_name='mtorch',
        )

    # Seed and cuDNN environment. cuDNN is switched OFF here; the warning now
    # reports that setting (it previously claimed cudnn.enabled=True, which
    # contradicted the assignment directly above it).
    # NOTE(review): confirm that disabling cuDNN is intentional.
    torch.backends.cudnn.enabled = False
    warnings.warn('cuDNN acceleration is disabled. torch.backends.cudnn.enabled=False')
    Util.set_seed(opt['seed'])

    # Logger and visual writer.
    phase_logger = InfoLogger(opt)
    phase_writer = VisualWriter(opt, phase_logger)
    try:
        phase_logger.info('Create the log file in directory {}.\n'.format(
            opt['path']['experiments_root']))

        # Networks and dataset; val_loader is None if phase is test.
        phase_loader, val_loader = define_dataloader(phase_logger, opt)
        networks = [define_network(phase_logger, opt, item_opt)
                    for item_opt in opt['model']['which_networks']]

        # Metrics and losses (optimizer/scheduler setup presumably happens
        # inside create_model -- not visible from here).
        metrics = [define_metric(phase_logger, item_opt)
                   for item_opt in opt['model']['which_metrics']]
        losses = [define_loss(phase_logger, item_opt)
                  for item_opt in opt['model']['which_losses']]

        model = create_model(
            opt=opt,
            networks=networks,
            phase_loader=phase_loader,
            val_loader=val_loader,
            losses=losses,
            metrics=metrics,
            logger=phase_logger,
            writer=phase_writer
        )

        phase_logger.info('Begin model {}.'.format(opt['phase']))
        if opt['phase'] == 'train':
            model.train()
        else:
            model.test()
    finally:
        # Close the writer even if building the model, training or testing raises.
        phase_writer.close()

    # Release the NCCL process group once this worker has finished normally.
    if opt['distributed'] and torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='config/EMDiffuse-n.json',
                        help='JSON file for configuration')
    parser.add_argument('--path', type=str, default=None, help='path of cropped patches')
    parser.add_argument('-p', '--phase', type=str, choices=['train', 'test'], help='Run train or test', default='train')
    parser.add_argument('-b', '--batch', type=int, default=None, help='Batch size in every gpu')
    parser.add_argument('--gpu', type=str, default=None, help='the gpu devices used')
    parser.add_argument('-d', '--debug', action='store_true')
    parser.add_argument('-z', '--z_times', default=None, type=int, help='The anisotropy time of the volume em')
    parser.add_argument('-P', '--port', default='21012', type=str)
    parser.add_argument('--mean', type=int, default=2,
                        help='EMDiffuse samples one plausible solution from the distribution. '
                             'The number of samples to generate and average')
    parser.add_argument('--lr', type=float, default=5e-5, help='Learning rate')
    parser.add_argument('--step', type=int, default=None, help='Steps of the diffusion process. More steps lead to '
                                                               'better image quality. ')
    parser.add_argument('--resume', type=str, default=None,
                        help='Resume state path and load epoch number e.g., experiments/EMDiffuse-n/2720')

    # Parse CLI arguments into the option dict (Praser.parse presumably also
    # loads the JSON file named by --config -- verify in core/praser.py).
    args = parser.parse_args()
    opt = Praser.parse(args)

    # CUDA devices: restrict visibility before any CUDA call in this process,
    # so that local ranks 0..N-1 map onto the selected physical GPUs.
    gpu_str = ','.join(str(x) for x in opt['gpu_ids'])
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_str
    print('export CUDA_VISIBLE_DEVICES={}'.format(gpu_str))

    # Use DistributedDataParallel (DDP) with one spawned process per GPU for
    # multi-GPU runs; otherwise run a single worker in this process.
    # TODO: multi-GPU on multiple machines
    if opt['distributed']:
        ngpus_per_node = len(opt['gpu_ids'])  # or torch.cuda.device_count()
        opt['world_size'] = ngpus_per_node
        opt['init_method'] = 'tcp://127.0.0.1:' + args.port
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt))
    else:
        opt['world_size'] = 1
        main_worker(0, 1, opt)