-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsave_predictions.py
More file actions
executable file
·118 lines (90 loc) · 4.09 KB
/
save_predictions.py
File metadata and controls
executable file
·118 lines (90 loc) · 4.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from utils.functions import *
from utils.metrics import IoUMetric
from helper import load_datasetloader, load_solvers
from NuscenesDataset.common import CLASSES
from NuscenesDataset.save_pred import BaseSave
import matplotlib.pyplot as plt
import dill
def main():
    """Parse CLI options, configure logging, and run ``test``.

    Log records are written (file overwritten each run) to
    ``./saved_models/<dataset_type>_<model_name>_model<exp_id>/pred_save.log``.
    Records from this module's logger are additionally echoed to stdout.
    Any ``Exception`` raised by ``test`` is logged together with its traceback
    instead of propagating; ``KeyboardInterrupt``/``SystemExit`` still propagate.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_id', type=int, default=624)
    parser.add_argument('--gpu_num', type=int, default=0)
    parser.add_argument('--dataset_type', type=str, default='nuscenes')
    parser.add_argument('--model_name', type=str, default='Scratch')
    parser.add_argument('--is_test_all', type=int, default=1)
    parser.add_argument('--target', type=str, default='lane')
    parser.add_argument('--model_num', type=int, default=24)
    parser.add_argument('--save_original_images', type=int, default=0)
    parser.add_argument('--threshold', type=float, default=0.4)
    args = parser.parse_args()

    # logging setting: the log file lives in the saved-model folder
    folder_name = args.dataset_type + '_' + args.model_name + '_model' + str(args.exp_id)
    save_dir = os.path.join('./saved_models/', folder_name)
    logging.basicConfig(
        filename=os.path.join(save_dir, 'pred_save.log'),
        filemode="w",
        format='%(asctime)s %(levelname)s:%(message)s',
        level=logging.INFO,
        datefmt='%m/%d/%Y %I:%M:%S %p',
    )
    logger = logging.getLogger(__name__)
    consoleHandler = logging.StreamHandler(stream=sys.stdout)
    consoleHandler.setLevel(level=logging.DEBUG)
    logger.addHandler(consoleHandler)

    # run test() (prediction saving)
    try:
        test(args, logger)
    except Exception:
        # `except Exception` (not bare `except:`) so Ctrl-C / SystemExit are not
        # swallowed. logger.exception records the traceback on both the console
        # handler and (via propagation to the root logger) the log file.
        logger.exception('Prediction saving failed')
def _dill_dump(obj, file_path):
    """Serialize ``obj`` to ``file_path`` using dill's highest protocol."""
    with open(file_path, 'wb') as f:
        dill.dump(obj, f, protocol=dill.HIGHEST_PROTOCOL)


def test(args, logger):
    """Run a saved model on the test split and dump per-batch results with dill.

    The training config is read from
    ``./saved_models/<dataset_type>_<model_name>_model<exp_id>/config.pkl`` and
    checkpoint ``args.model_num`` is restored. For batch indices 0..500
    (inclusive), a file ``<target>_%04d.ckpl`` holding a dict with keys
    ``'gt'``, ``'pred'`` and ``'aux'`` is written to
    ``./VisResults/<model_name>_exp<exp_id>_m<model_num>``. When
    ``args.save_original_images == 1``, ``cam_%04d.ckpl`` is written as well.

    Args:
        args: parsed CLI namespace produced by ``main``.
        logger: logger used for progress/config messages.
    """
    # CUDA setting
    os.environ["CUDA_VISIBLE_DEVICES"] = str(int(args.gpu_num))

    # type definition
    _, float_dtype = get_dtypes(useGPU=True)

    # path to saved network
    folder_name = args.dataset_type + '_' + args.model_name + '_model' + str(args.exp_id)
    path = os.path.join('./saved_models/', folder_name)

    # load parameter setting
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # config files you produced yourself.
    with open(os.path.join(path, 'config.pkl'), 'rb') as f:
        saved_args = pickle.load(f)
    # force single-process, full-precision evaluation regardless of training setup
    saved_args.ddp = 0
    saved_args.bool_mixed_precision = 0
    saved_args.save_dir = path
    saved_args.exp_id = args.exp_id

    # save folder
    folder_name = args.model_name + '_exp' + str(args.exp_id) + '_m' + str(args.model_num)
    save_dir = os.path.join('./VisResults/', folder_name)
    os.makedirs(save_dir, exist_ok=True)

    print_training_info(saved_args, logger)
    logger.info(f">> Test target : {args.target}")

    # load test data
    dataset, data_loader, _ = load_datasetloader(args=saved_args, dtype=torch.FloatTensor,
                                                 world_size=1, rank=0, mode='test')

    # define network
    solver = load_solvers(saved_args, dataset.num_scenes, world_size=1, rank=0, logger=logger,
                          dtype=float_dtype, isTrain=False)
    save = BaseSave(label_indices=solver.cfg['label_indices'], SEMANTICS=CLASSES, threshold=args.threshold)

    # load pretrained network
    solver.load_pretrained_network_params(args.model_num)
    solver.mode_selection(isTrain=False)

    last_batch_index = 500  # batches 0..500 (inclusive) are saved
    for b, batch in enumerate(tqdm(data_loader, desc='Test')):
        if b > last_batch_index:
            # `break` rather than `continue`: the original kept loading and
            # discarding every remaining batch of the test set for nothing.
            break

        # inference
        with torch.no_grad():
            pred = solver.model(batch, float_dtype, rank=0)

        # save
        if args.save_original_images == 1:
            img = save.return_cams(batch)
            _dill_dump(img, os.path.join(save_dir, 'cam_%04d.ckpl' % b))

        gt_bev = save.return_gt(batch, args.target)
        pred_bev = save.return_pred(pred, args.target)
        aux_bev = save.return_pred(pred, 'intm')
        data = {'gt': gt_bev, 'pred': pred_bev, 'aux': aux_bev}
        _dill_dump(data, os.path.join(save_dir, args.target + '_%04d.ckpl' % b))
if __name__ == "__main__":
    # Script entry point: run only when executed directly, not when imported.
    main()