forked from ssghost/vegans
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
128 lines (98 loc) · 4.16 KB
/
utils.py
File metadata and controls
128 lines (98 loc) · 4.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import os
def batch_transform(batch, transform):
    """Applies a transform to every sample of a batch and re-stacks them.

    Keyword arguments:
    - batch (``Tensor``): a batch of samples; it is split along its first
    dimension.
    - transform (callable): A function/transform to apply to each sample of
    ``batch``. It must return a tensor, and all returned tensors must have
    the same shape so they can be stacked.

    Returns:
    A tensor holding the transformed samples stacked along a new first
    dimension.

    """
    # torch.unbind removes dimension 0 of ``batch`` and yields one slice per
    # sample; ``transform`` is mapped over those slices in order.
    samples = torch.unbind(batch)
    transformed = list(map(transform, samples))

    # Reassemble the transformed samples into a single batch tensor
    return torch.stack(transformed)
def imshow_batch(images, labels):
    """Displays two grids of images. The top grid displays ``images``
    and the bottom grid ``labels``.

    Keyword arguments:
    - images (``Tensor``): a 4D mini-batch tensor of shape
    (B, C, H, W)
    - labels (``Tensor``): a 4D mini-batch tensor of shape
    (B, C, H, W)

    The function blocks on ``plt.show()`` and returns nothing.

    """
    # Make a grid with the images and labels and convert it to numpy.
    # ``Tensor.numpy()`` only works for CPU tensors that do not require grad,
    # so detach and move to the CPU first; for tensors that already satisfy
    # this the data is unchanged.
    image_grid = torchvision.utils.make_grid(images).detach().cpu().numpy()
    label_grid = torchvision.utils.make_grid(labels).detach().cpu().numpy()

    # Two stacked axes: images on top, labels below
    _, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 7))

    # Move the channel axis last, which is the layout ``imshow`` expects
    ax1.imshow(np.transpose(image_grid, (1, 2, 0)))
    ax2.imshow(np.transpose(label_grid, (1, 2, 0)))

    plt.show()
def save_checkpoint(model, optimizer, epoch, miou, args):
    """Saves the model in a specified directory with a specified name.

    Two files are written to ``args.save_dir``:
    - "``args.name``": a checkpoint dict stored with ``torch.save``,
    containing the keys 'epoch', 'miou', 'state_dict' and 'optimizer'.
    - "``args.name``_summary.txt": a text summary listing every argument in
    ``args`` (sorted by name) followed by ``epoch`` and ``miou``.

    Keyword arguments:
    - model (``nn.Module``): The model to save.
    - optimizer (``torch.optim``): The optimizer state to save.
    - epoch (``int``): The current epoch for the model.
    - miou (``float``): The mean IoU obtained by the model.
    - args: An object holding the arguments used to train ``model`` (e.g.
    the namespace returned by ``ArgumentParser.parse_args``). It must
    provide the attributes ``name`` and ``save_dir`` and work with ``vars``.

    Raises:
    ``AssertionError`` if ``args.save_dir`` is not an existing directory.

    """
    name = args.name
    save_dir = args.save_dir

    # Raised explicitly rather than with an ``assert`` statement so the check
    # is not stripped when Python runs with -O. The exception type is kept so
    # existing callers see the same error.
    if not os.path.isdir(save_dir):
        raise AssertionError(
            "The directory \"{0}\" doesn't exist.".format(save_dir))

    # Save model
    model_path = os.path.join(save_dir, name)
    checkpoint = {
        'epoch': epoch,
        'miou': miou,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }
    torch.save(checkpoint, model_path)

    # Save arguments
    summary_filename = os.path.join(save_dir, name + '_summary.txt')
    with open(summary_filename, 'w') as summary_file:
        summary_file.write("ARGUMENTS\n")
        # Sorted so the summary is stable regardless of definition order
        for arg in sorted(vars(args)):
            summary_file.write("{0}: {1}\n".format(arg, getattr(args, arg)))

        summary_file.write("\nBEST VALIDATION\n")
        summary_file.write("Epoch: {0}\n".format(epoch))
        summary_file.write("Mean IoU: {0}\n".format(miou))
def load_checkpoint(model, optimizer, folder_dir, filename,
                    map_location=None):
    """Loads a stored checkpoint into the given model and optimizer.

    The checkpoint file is expected to contain a dict with the keys
    'state_dict', 'optimizer', 'epoch' and 'miou'.

    Keyword arguments:
    - model (``nn.Module``): The stored model state is copied to this model
    instance.
    - optimizer (``torch.optim``): The stored optimizer state is copied to this
    optimizer instance.
    - folder_dir (``string``): The path to the folder where the saved model
    state is located.
    - filename (``string``): The model filename.
    - map_location (optional): Forwarded to ``torch.load`` to remap storage
    locations, e.g. ``"cpu"`` to load a checkpoint saved on a GPU on a machine
    without one. Default: None (``torch.load`` default behaviour).

    Returns:
    The tuple ``(model, optimizer, epoch, miou)``: the two instances passed
    in, updated in place, followed by the epoch and mean IoU stored in the
    checkpoint.

    Raises:
    ``AssertionError`` if ``folder_dir`` is not an existing directory or if
    ``filename`` is not an existing file inside it.

    """
    # Raised explicitly rather than with ``assert`` statements so the checks
    # are not stripped when Python runs with -O. The exception type is kept so
    # existing callers see the same error.
    if not os.path.isdir(folder_dir):
        raise AssertionError(
            "The directory \"{0}\" doesn't exist.".format(folder_dir))

    # Locate the checkpoint file inside the folder
    model_path = os.path.join(folder_dir, filename)
    if not os.path.isfile(model_path):
        raise AssertionError(
            "The model file \"{0}\" doesn't exist.".format(filename))

    # NOTE: torch.load unpickles the file, so only load checkpoints that come
    # from a trusted source.
    checkpoint = torch.load(model_path, map_location=map_location)

    # Load the stored model parameters to the model instance
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch = checkpoint['epoch']
    miou = checkpoint['miou']

    return model, optimizer, epoch, miou