-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgnn.py
More file actions
123 lines (104 loc) · 4.11 KB
/
Copy pathgnn.py
File metadata and controls
123 lines (104 loc) · 4.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv, LayerNorm
class VRPEdgePredictor(nn.Module):
    """Supervised edge co-membership graph neural network.

    For every edge ``(u, v)`` listed in ``edge_index`` the model emits one
    logit scoring whether ``u`` and ``v`` belong to the same cluster.

    Pipeline:
        1. Node and edge features are each projected to ``hidden_dim``.
        2. ``num_layers`` rounds of ``GATv2Conv`` message passing refine the
           node states; each round is wrapped in a residual connection
           followed by layer normalization.
        3. An MLP head maps ``[h_src, h_dst, edge_embedding]`` to one logit.

    Args:
        node_in_dim: Number of raw input features per node.
        edge_in_dim: Number of raw input features per edge.
        hidden_dim: Width of all hidden node/edge representations.
        num_layers: Number of message-passing rounds.
        dropout: Dropout probability used in the embeddings, after every
            message-passing round, and inside the prediction head.
        heads: Number of attention heads per ``GATv2Conv`` layer. The heads
            are averaged (``concat=False``), so the output width stays
            ``hidden_dim`` whatever value is chosen.
    """

    def __init__(self, node_in_dim=5, edge_in_dim=2, hidden_dim=64,
                 num_layers=4, dropout=0.2, heads=4):
        super().__init__()

        # Initial embeddings: raw features -> hidden_dim.
        self.node_embed = self._make_embedding(node_in_dim, hidden_dim, dropout)
        self.edge_embed = self._make_embedding(edge_in_dim, hidden_dim, dropout)

        # Message-passing layers. The three lists are kept in lockstep: one
        # (conv, norm, dropout) triple per round.
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(
                GATv2Conv(
                    in_channels=hidden_dim,
                    out_channels=hidden_dim,
                    heads=heads,
                    concat=False,  # Averages the heads back to hidden_dim
                    edge_dim=hidden_dim,
                    add_self_loops=False,
                )
            )
            # NOTE: this is torch_geometric's LayerNorm, not nn.LayerNorm.
            # Per the PyG documentation its default mode computes statistics
            # per graph, which is why forward() accepts a `batch` vector.
            self.norms.append(LayerNorm(hidden_dim))
            self.dropouts.append(nn.Dropout(dropout))

        # Edge predictor head.
        # Takes [node_u, node_v, embedded_edge_feat] -> one logit per edge.
        self.edge_mlp = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),  # Logits -> passed to BCEWithLogitsLoss
        )

    @staticmethod
    def _make_embedding(in_dim, out_dim, dropout):
        """Build the Linear -> LayerNorm -> Dropout -> ReLU input projection."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.LayerNorm(out_dim),
            nn.Dropout(dropout),
            nn.ReLU(),
        )

    def forward(self, x, edge_index, edge_attr, batch=None):
        """Compute one co-membership logit per edge.

        Args:
            x: Node features, ``[num_nodes, node_in_dim]``.
            edge_index: Edge endpoints as node indices, ``[2, num_edges]``.
            edge_attr: Edge features, ``[num_edges, edge_in_dim]``.
            batch: Optional ``[num_nodes]`` vector assigning each node to a
                graph, forwarded to every message-passing LayerNorm so that
                mini-batched graphs are normalized independently. When
                omitted, all nodes in ``x`` are normalized together as if
                they formed a single graph.

        Returns:
            Logits of shape ``[num_edges]``, ordered like the columns of
            ``edge_index``.
        """
        # 1. Initial projections
        h = self.node_embed(x)
        e = self.edge_embed(edge_attr)

        # 2. Message passing
        for conv, norm, drop in zip(self.convs, self.norms, self.dropouts):
            h_next = conv(h, edge_index, edge_attr=e)
            # Residual connection, then normalization (per graph when `batch`
            # is supplied).
            h = norm(h + drop(F.relu(h_next)), batch)

        # 3. Edge prediction: gather both endpoint states and append the
        # embedded edge feature.
        src, dst = edge_index
        edge_cat = torch.cat([h[src], h[dst], e], dim=-1)

        # Drop the trailing singleton dim: [num_edges, 1] -> [num_edges]
        return self.edge_mlp(edge_cat).squeeze(-1)
if __name__ == "__main__":
    # Smoke test: load one labelled graph from disk, build a model with the
    # default sizes and run a single inference-mode forward pass, printing the
    # tensor shapes along the way.
    from pathlib import Path

    from dataset import VRPLabelDataset

    print("\nInitializing Test Forward Pass...")

    # The data directory is expected one level above this file's directory.
    DATA_DIR = Path(__file__).resolve().parent.parent / "data"
    dataset = VRPLabelDataset(
        graphs_dir=str(DATA_DIR / "xset_graphs"),
        labels_dir=str(DATA_DIR / "nsga_labels"),
        max_nodes=200,
        param_filter="100_0p50_0p20",
        seed_filter=None,
        knn_k=20,
    )

    if len(dataset) > 0:
        sample = dataset[0]
        model = VRPEdgePredictor(node_in_dim=5, edge_in_dim=2, hidden_dim=64, num_layers=4)

        print("\nModel Architecture:")
        print(f"Node embedding params: {sum(p.numel() for p in model.node_embed.parameters())}")
        print(f"Num params overall: {sum(p.numel() for p in model.parameters())}")

        print("\nRunning Forward Pass...")
        # eval() disables dropout; no_grad() skips autograd bookkeeping.
        model.eval()
        with torch.no_grad():
            logits = model(sample.x, sample.edge_index, sample.edge_attr)

        print(f"Input nodes: {sample.x.shape}")
        print(f"Input edges: {sample.edge_attr.shape}")
        print(f"Output logits: {logits.shape}")
        print("Test Successful!")
    else:
        # Previously this path exited silently, which was indistinguishable
        # from a hang or crash right after the banner above.
        print("No samples matched the dataset filters; forward pass skipped.")