-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtrain.py
More file actions
55 lines (43 loc) · 1.48 KB
/
train.py
File metadata and controls
55 lines (43 loc) · 1.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
from matplotlib import pyplot as plt
from torch_geometric.datasets import Planetoid
from core import get_model
from core.utils import GraphUtils
def build_adjacency_matrix(edge_index, num_nodes):
    """Build a dense, symmetric 0/1 adjacency matrix from an edge list.

    Args:
        edge_index: array-like of shape (2, E); row 0 holds source node
            indices and row 1 holds target node indices.
        num_nodes: number of nodes N.

    Returns:
        float64 ndarray of shape (N, N) with ``A[i, j] = A[j, i] = 1`` for
        every edge (i, j). Both directions are set, so the result is
        symmetric even if only one direction is listed. The diagonal is
        set only if the edge list itself contains self-loops.
    """
    edge_index = np.asarray(edge_index)
    adj = np.zeros((num_nodes, num_nodes))
    src, dst = edge_index[0], edge_index[1]
    adj[src, dst] = 1
    adj[dst, src] = 1
    return adj


def one_hot(labels, num_classes=None):
    """Return a one-hot encoding of integer class labels.

    Args:
        labels: 1-D array-like of non-negative integer class ids.
        num_classes: width of the encoding. Defaults to ``max(labels) + 1``.
            This is used instead of counting unique labels so that a class
            id absent from ``labels`` cannot push an index past the end.

    Returns:
        float64 ndarray of shape (len(labels), num_classes).
    """
    labels = np.asarray(labels)
    if num_classes is None:
        num_classes = int(labels.max()) + 1 if labels.size else 0
    return np.eye(num_classes)[labels]


def train(model, X, A, y, epochs, lr):
    """Run full-batch gradient training and return the per-epoch losses.

    Each epoch calls ``model.forward(X, A)``, computes the loss with
    ``GraphUtils.loss_function``, prints it, then calls
    ``model.backward(y, y_hat, lr=lr)``.
    """
    loss_list = []
    for epoch in range(epochs):
        y_hat = model.forward(X, A)
        loss = GraphUtils.loss_function(y, y_hat)
        loss_list.append(loss)
        print(f"the epoch {epoch+1}/{epochs} : \n The current Loss => {loss}")
        model.backward(y, y_hat, lr=lr)
    print("train finished")
    return loss_list


def plot_loss(loss_list):
    """Plot the loss curve with 1-based epochs (matching the console log)."""
    epoch_axis = range(1, len(loss_list) + 1)
    plt.plot(epoch_axis, loss_list)
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Loss Curve")
    plt.grid(True)  # Add grid lines for better readability
    plt.show()


def main():
    """Load Cora, train the selected model, and plot its loss curve."""
    dataset = Planetoid(root="data/Cora", name="Cora")
    data = dataset[0]

    A = build_adjacency_matrix(data.edge_index.numpy(), data.num_nodes)
    X = data.x.numpy()
    labels = data.y.numpy()

    num_labels = int(labels.max()) + 1
    y = one_hot(labels, num_labels)

    input_dim = X.shape[1]
    hidden_dim = 16
    output_dim = num_labels
    epochs = 5
    lr = 0.1

    # Select model by name: "GCN", "GAT", "GIN", or "GraphSAGE"
    MODEL_NAME = "GIN"
    model = get_model(MODEL_NAME, input_dim, hidden_dim, output_dim)

    # NOTE(review): training uses every node's label, including nodes that
    # Planetoid's standard split reserves for validation/test
    # (data.train_mask / val_mask / test_mask). Confirm this is intended.
    loss_list = train(model, X, A, y, epochs, lr)
    plot_loss(loss_list)


if __name__ == "__main__":
    main()