-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path2D_train.py
More file actions
105 lines (77 loc) · 2.64 KB
/
2D_train.py
File metadata and controls
105 lines (77 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.optimizer import Optimizer
from argparse import ArgumentParser
from utils.get_data import get_2d_data
from utils.log_reg_utils import update_decision_boundary
class Abs(nn.Module):
    """Element-wise absolute-value activation, usable as a layer in ``nn.Sequential``."""

    def __init__(self):
        super(Abs, self).__init__()

    def forward(self, x):
        """Return ``|x|`` computed element-wise."""
        return x.abs()
class NNet(nn.Module):
    """Small fully connected binary classifier for 2-D points.

    Architecture: ``2 -> 3 -> 100 -> 3 -> 1`` linear layers with ``Tanh``
    after each hidden layer and a final ``Sigmoid``. The output therefore
    has last dimension 1, with values in ``[0, 1]``.
    """

    def __init__(self):
        super().__init__()
        # EDIT HERE: change the hidden widths to try other architectures.
        # Layers are created in the same order as an explicit listing would
        # create them. The parameter names in ``state_dict`` (``model.0``,
        # ``model.2``, ...) stay the same, and the RNG is consumed identically.
        widths = (2, 3, 100, 3)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.Tanh()))
        layers.extend((nn.Linear(widths[-1], 1), nn.Sigmoid()))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return class-1 probabilities for the input points ``x``.

        ``x`` is cast to float32 first, so that tensors built from float64
        NumPy arrays can be passed directly.
        """
        return self.model(x.float())
def _parse_args():
    """Parse the command-line options for the 2-D training demo.

    Returns:
        An ``argparse.Namespace`` with ``lr`` (float), ``dataset`` (str) and
        ``steps`` (int).
    """
    parser: ArgumentParser = ArgumentParser()
    parser.add_argument('--lr', type=float, default=0.01, help="Learning rate")
    parser.add_argument('--dataset', type=str, default="half_moons", help="lin, xor, half_moons")
    parser.add_argument('--steps', type=int, default=500, help="Number of training steps")
    return parser.parse_args()


def _build_mesh(x: np.ndarray, mesh_size: int):
    """Build an evaluation grid spanning the bounding box of the first two columns of ``x``.

    Returns:
        ``(x1_mesh, x2_mesh, mesh_points)``. ``x1_mesh`` and ``x2_mesh`` are the
        ``(mesh_size, mesh_size)`` arrays from ``np.meshgrid``. ``mesh_points``
        stacks their flattened values into a ``(mesh_size * mesh_size, 2)`` array.
    """
    x1_vals = np.linspace(np.min(x[:, 0]), np.max(x[:, 0]), mesh_size)
    x2_vals = np.linspace(np.min(x[:, 1]), np.max(x[:, 1]), mesh_size)
    x1_mesh, x2_mesh = np.meshgrid(x1_vals, x2_vals)
    mesh_points = np.stack(
        (x1_mesh.reshape(mesh_size * mesh_size), x2_mesh.reshape(mesh_size * mesh_size)),
        axis=1,
    )
    return x1_mesh, x2_mesh, mesh_points


def main():
    """Train ``NNet`` on a 2-D toy dataset while animating its decision boundary.

    Command-line options:
        ``--lr``: Adam learning rate.
        ``--dataset``: name passed to ``get_2d_data``.
        ``--steps``: number of full-batch optimisation steps.
    """
    args = _parse_args()

    torch.set_num_threads(1)
    # Seed both RNGs. numpy is seeded right before data loading; presumably
    # get_2d_data uses it (confirm). torch must be seeded too, otherwise the
    # network's weight initialisation differs on every run.
    np.random.seed(42)
    torch.manual_seed(42)
    plt.ion()

    x, y = get_2d_data(args.dataset)

    fig, ax = plt.subplots(1, 1)
    ax.set(adjustable='box')
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title("Decision Boundary")

    mesh_size: int = 100
    x1_mesh, x2_mesh, mesh_points = _build_mesh(x, mesh_size)

    nnet = NNet()
    nnet.eval()
    update_decision_boundary(ax, x, y, x1_mesh, x2_mesh, mesh_points, mesh_size, nnet)
    plt.pause(0.5)

    # The training set never changes, so convert it to tensors once instead of every step.
    x_t = torch.tensor(x)
    y_t = torch.tensor(y).float()

    criterion = nn.BCELoss()
    optimizer: Optimizer = optim.Adam(nnet.parameters(), lr=args.lr, weight_decay=0.0)
    for i in range(args.steps):
        # plot the current model
        nnet.eval()
        update_decision_boundary(ax, x, y, x1_mesh, x2_mesh, mesh_points, mesh_size, nnet)

        # forward (train)
        nnet.train()
        optimizer.zero_grad()
        y_hat = nnet(x_t)

        # loss: the network output has a trailing dimension of 1; drop it to match y
        loss = criterion(y_hat[:, 0], y_t)

        # backwards + step
        loss.backward()
        optimizer.step()

        print(f"Itrs: {i}, Train: {loss.item():.2E}")
        plt.pause(0.01)

    # The loop plots *before* each update, so draw once more to show the final model.
    nnet.eval()
    update_decision_boundary(ax, x, y, x1_mesh, x2_mesh, mesh_points, mesh_size, nnet)
    plt.draw()
    plt.show(block=True)
if __name__ == "__main__":
    # Launch the interactive training demo only when run as a script, not on import.
    main()