-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_processing.py
More file actions
55 lines (41 loc) · 2.41 KB
/
Copy pathdata_processing.py
File metadata and controls
55 lines (41 loc) · 2.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
from torch_geometric.utils import to_dense_adj
import pickle
# This assumes the datasets follow the structure of the TUDataset datasets,
# i.e. a list of graph Data objects of the form Data(edge_index=[2, edges], x=[n, m], y=tensor([y])),
# where edges is the number of edges (each undirected edge counted twice), n is the number of nodes,
# m is the number of node features, and y is the graph label (i.e. the target).
def filter_small_graphs(dataset, max_nodes=150, min_nodes=6):
    """Load a graph dataset and keep only graphs within a node-count range.

    Parameters
    ----------
    dataset : str or iterable
        Name of a TUDataset (e.g. ``'PROTEINS'``), which is loaded with
        ``TUDataset(root=f'/tmp/{dataset}', name=dataset)``. An already-loaded
        iterable of graph objects exposing ``num_nodes`` and ``y`` is also
        accepted and is filtered directly.
    max_nodes : int
        Largest number of nodes a graph may have to be kept (inclusive).
    min_nodes : int
        Smallest number of nodes a graph may have to be kept (inclusive).

    Returns
    -------
    filtered : list
        The graph objects whose node count lies in ``[min_nodes, max_nodes]``,
        in their original order.
    targets : numpy.ndarray
        Integer array of graph labels (``int(graph.y)``), aligned with
        ``filtered``.
    """
    if isinstance(dataset, str):
        # Only needed when a dataset name has to be resolved into graphs.
        from torch_geometric.datasets import TUDataset
        graphs = TUDataset(root=f'/tmp/{dataset}', name=dataset)
    else:
        graphs = dataset

    # Drop graphs with more than max_nodes nodes or fewer than min_nodes nodes.
    # min_nodes is meant to be the max photon number, so that every kept graph
    # yields the same number of orbits.
    filtered = []
    targets = []
    for graph in graphs:
        if min_nodes <= graph.num_nodes <= max_nodes:
            filtered.append(graph)
            targets.append(int(graph.y))
    # dtype=int keeps the labels integer-typed even when no graph survives the
    # filter (np.array([]) would otherwise default to float64).
    targets = np.array(targets, dtype=int)
    print(f"Filtered dataset: {len(filtered)} graphs with ≤{max_nodes} nodes and ≥{min_nodes} nodes.")
    return filtered, targets
def edge_index_to_adj(filtered, undirected=True):
    """Convert each graph's edge list into a dense adjacency matrix.

    Parameters
    ----------
    filtered : iterable
        Graph objects exposing ``edge_index`` (as accepted by PyG's
        ``to_dense_adj``) and ``num_nodes``.
    undirected : bool
        If True, each matrix is symmetrised with ``np.maximum(A, A.T)`` so an
        edge stored in only one direction still appears in both. For datasets
        that already list both directions, this is a no-op. If False, the
        matrix reflects ``edge_index`` exactly as given.

    Returns
    -------
    list of numpy.ndarray
        One ``(num_nodes, num_nodes)`` matrix per graph, in input order.
    """
    adj = []
    for graph in filtered:
        # to_dense_adj returns a batch of adjacency matrices; [0] removes the
        # batch dimension since we convert one graph at a time. max_num_nodes
        # pads to the graph's full size; otherwise nodes without edges at the
        # end of the index range would be dropped from the matrix.
        adj_k = to_dense_adj(graph.edge_index, max_num_nodes=graph.num_nodes)[0].cpu().numpy()
        if undirected:
            adj_k = np.maximum(adj_k, adj_k.T)
        adj.append(adj_k)
    print(f"Converted {len(adj)} graphs to adjacency matrices.")
    return adj
if __name__ == "__main__":
    import os

    # TUDataset names are case-sensitive: the protein dataset is named "PROTEINS".
    filtered, targets = filter_small_graphs(dataset='PROTEINS')
    adj = edge_index_to_adj(filtered)
    # Save the processed data as a pickle file containing a dictionary with keys "adj" and "targets".
    # Create the output directory first so the write does not fail when data/ does not exist yet.
    os.makedirs('data', exist_ok=True)
    with open('data/proteins.pkl', 'wb') as f:
        pickle.dump({'adj': adj, 'targets': targets}, f)