-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutils.py
More file actions
106 lines (75 loc) · 2.79 KB
/
utils.py
File metadata and controls
106 lines (75 loc) · 2.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import numpy as np
import torch
import torch.optim as optim
import matplotlib.pyplot as plt
def logtransform(x):
    """Signed log transform: y = sign(x) * log(1 + |x|).

    Compresses large magnitudes while preserving sign; invertible via
    ``loguntransform``. Uses ``np.log1p`` for numerical accuracy when
    ``|x|`` is very small (``np.log(np.abs(x) + 1)`` underflows to 0
    for tiny inputs, breaking invertibility in the last bits).

    Parameters
    ----------
    x : array_like
        Input value(s), any real number.

    Returns
    -------
    ndarray or scalar
        Transformed value(s), same shape as ``x``.
    """
    return np.sign(x) * np.log1p(np.abs(x))
def loguntransform(x):
    """Inverse of ``logtransform``: y = sign(x) * (exp(|x|) - 1).

    Uses ``np.expm1`` instead of ``np.exp(np.abs(x)) - 1`` for
    numerical accuracy when ``|x|`` is very small (avoids catastrophic
    cancellation in the subtraction).

    Parameters
    ----------
    x : array_like
        Log-transformed value(s).

    Returns
    -------
    ndarray or scalar
        Original-scale value(s), same shape as ``x``.
    """
    return np.sign(x) * np.expm1(np.abs(x))
def train(model, noise_fn, loss_fn, train_ds, test_ds, iters=10000, lr=1e-4, batch_size=256):
    """Train a generator ``model`` against samples from ``train_ds``.

    Each step draws a random minibatch of real samples and a batch of
    noise, pushes the noise through the model, and minimizes
    ``loss_fn(fake, real)`` with Adam. The learning rate decays by a
    factor of 0.985 every 200 iterations.

    Parameters
    ----------
    model : torch.nn.Module
        Generator mapping noise vectors to samples; trained in place.
    noise_fn : callable
        ``noise_fn(batch_size)`` -> array-like batch of noise inputs.
    loss_fn : callable
        ``loss_fn(fake, real)`` -> scalar loss tensor (printed as 'mmd').
    train_ds : tensor-like
        Training samples, indexable with an integer array.
    test_ds : tensor-like
        Unused here; kept for interface compatibility.
    iters : int
        Number of optimization steps.
    lr : float
        Initial Adam learning rate.
    batch_size : int
        Samples drawn per step (with replacement).
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    optimizerG = optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.9))
    schedG = optim.lr_scheduler.ExponentialLR(optimizerG, 0.985, last_epoch=-1)
    for iteration in range(iters):
        model.zero_grad()
        # Sample a random minibatch (with replacement) from the training set.
        batch = np.random.randint(0, len(train_ds), size=batch_size)
        real = train_ds[batch].to(device)
        noise = torch.as_tensor(noise_fn(batch_size), dtype=torch.float32).to(device)
        fake = model(noise)
        loss = loss_fn(fake, real)
        loss.backward()
        optimizerG.step()
        # Write logs and decay the learning rate every 200 iterations.
        if (iteration + 1) % 200 == 0:
            # get_last_lr() replaces the deprecated get_lr(), which emits a
            # UserWarning and may report an incorrect value when called
            # outside of scheduler.step().
            print('Iteration', iteration + 1,
                  'Learning rate:', schedG.get_last_lr(),
                  'mmd:', loss.item())
            schedG.step()
def compare(model, noise_fn, pdf_fun=None, tgt_data=None, model_path=None, output_transform=None, std=True, modelims=(-10,10), taillims=(-100,100), nsamples=1000000):
    """Visually compare model samples against a target density.

    Draws ``nsamples`` noise vectors, pushes them through ``model`` on the
    CPU, and shows two density histograms: one zoomed on the mode region
    (``modelims``) and one over the tails (``taillims``) with a log-scaled
    y axis. If ``pdf_fun`` is given, its curve is overlaid on each plot.

    Parameters
    ----------
    model : torch.nn.Module
        Trained generator; moved to CPU in place.
    noise_fn : callable
        ``noise_fn(n)`` -> batch of n noise inputs accepted by ``model``.
        NOTE(review): unlike ``train``, the result is fed to the model
        directly, so it presumably must already be a torch tensor — confirm
        against callers.
    pdf_fun : callable, optional
        Vectorized target pdf evaluated on the bin grid for the overlay.
    tgt_data : unused
        Kept for interface compatibility.
    model_path : str, optional
        If given, weights are loaded from this checkpoint (mapped to CPU).
    output_transform : callable, optional
        Post-processing applied to the generated samples (e.g. an inverse
        log transform).
    std : unused
        Kept for interface compatibility.
    modelims, taillims : tuple of float
        x-axis ranges for the mode and tail plots.
    nsamples : int
        Number of samples to draw (default 1,000,000, the original
        hard-coded value).
    """
    noise = noise_fn(nsamples)
    model.cpu()
    if model_path is not None:
        # Map checkpoint tensors to CPU so GPU-saved weights load anywhere.
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
    fake = model(noise).detach().numpy()
    if output_transform is not None:
        fake = output_transform(fake)
    # One shared grid per region: histogram bin edges and pdf sample points
    # were identical arrays in spirit, so build each only once.
    bins_mode = np.linspace(modelims[0], modelims[1], 100)
    bins_tails = np.linspace(taillims[0], taillims[1], 100)
    if pdf_fun is not None:
        tgt_mode = pdf_fun(bins_mode).reshape(-1, 1)
        tgt_tails = pdf_fun(bins_tails).reshape(-1, 1)
    else:
        tgt_mode = None
        tgt_tails = None
    # Mode region.
    plt.figure()
    plt.hist(fake, bins=bins_mode, density=True)
    if tgt_mode is not None:
        # Plotting a None y-array raises inside matplotlib; only overlay
        # the target curve when a pdf was actually supplied.
        plt.plot(bins_mode, tgt_mode)
    plt.xlabel('x')
    plt.ylabel('Probability Density')
    plt.show()
    # Tails, log-scaled so rare events remain visible.
    plt.figure()
    plt.hist(fake, bins=bins_tails, density=True)
    if tgt_tails is not None:
        plt.plot(bins_tails, tgt_tails)
    plt.xlabel('x')
    plt.ylabel('Probability Density')
    plt.yscale('log')
    plt.show()