-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
126 lines (111 loc) · 6.48 KB
/
utils.py
File metadata and controls
126 lines (111 loc) · 6.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
import torch_geometric
from torch_sparse import SparseTensor
from torch_geometric.utils import dense_to_sparse
import random
import logging
from torch_geometric.data import InMemoryDataset
from torch_geometric.graphgym.config import cfg
class DataListSet(InMemoryDataset):
    """Minimal in-memory dataset built from an already-materialized list of graphs."""

    def __init__(self, datalist):
        super().__init__()
        # collate() packs the list into one big storage plus per-graph slices.
        collated, slice_index = self.collate(datalist)
        self.data = collated
        self.slices = slice_index
# TODO: be careful about the shape; perhaps correct for zinc, but not for all
def add_virtual_node_edge(data, format):
    """Append one virtual node to ``data`` and densify the graph in place.

    The graph is rewritten as a fully-connected graph on ``N + 1`` nodes
    (the extra node is the virtual node). A dense ``(N+1) x (N+1)`` edge-type
    matrix ``A`` is built in which real edges keep their (possibly shifted)
    bond type, edges to/from the virtual node get the sentinel type
    ``edge_encoder_num_types + 1``, and every self-loop is offset by
    ``edge_encoder_num_types + 4`` (the virtual node's own entry is zeroed
    first so its self-loop carries only the diagonal offset). The originals
    are preserved on ``x_original`` / ``edge_index_original`` /
    ``edge_attr_original``.

    Args:
        data: a PyG ``Data``-like graph with long-typed ``x`` and
            ``edge_attr``.
        format: ``'PyG-ZINC'``, ``'OGB'`` or ``'PyG-QM9'``; selects the
            dataset-specific feature handling (for ``'OGB'`` the branch is
            further chosen by ``cfg.train.pretrain.atom_bond_only``).

    Returns:
        The same ``data`` object, modified in place.

    Raises:
        NotImplementedError: for any other ``format``.
    """
    N = data.num_nodes
    data.num_node_per_graph = torch.tensor((N + 1), dtype=torch.long)
    if format == 'PyG-ZINC':
        # NOTE(review): the +1 offset on x is dataset-specific — confirm
        # before reusing this branch for other datasets.
        data.x_original = data.x + 1
        data.edge_index_original = data.edge_index
        data.edge_attr_original = data.edge_attr
        # Dense edge-type matrix; absent edges are 0, which is why node/edge
        # types elsewhere are shifted away from 0.
        A = SparseTensor(row=data.edge_index[0],
                         col=data.edge_index[1],
                         value=data.edge_attr,
                         sparse_sizes=(N + 1, N + 1)).coalesce().to_dense()
        # Virtual-node row/column get a dedicated edge type; the virtual
        # node's own entry stays 0 until the diagonal offset below.
        A[:, -1] = cfg.dataset.edge_encoder_num_types + 1
        A[-1, :] = cfg.dataset.edge_encoder_num_types + 1
        A[-1, -1] = 0
        # Distinct offset marks self-loops apart from real and virtual edges.
        A += torch.diag_embed(torch.ones([N + 1], dtype=torch.long) * (cfg.dataset.edge_encoder_num_types + 4))
        data.edge_attr = A.reshape(-1, 1).long()
        # Fully-connected edge_index over all N+1 nodes.
        adj = torch.ones([N + 1, N + 1], dtype=torch.long)
        data.edge_index = dense_to_sparse(adj)[0]
        # Shift node types by +1 and append the virtual node's type.
        data.x = torch.cat([data.x + 1,
                            torch.ones([1, 1], dtype=torch.long) * (cfg.dataset.node_encoder_num_types + 1)],
                           dim=0)
    elif format == 'OGB' and cfg.train.pretrain.atom_bond_only:
        # Pretraining variant: keep only the atom type (column 0) and the
        # bond type (column 0), dropping the remaining OGB feature columns.
        data.x_original = data.x[:, 0].unsqueeze(1)
        data.edge_index_original = data.edge_index
        data.edge_attr_original = data.edge_attr
        # +1 shifts bond types so 0 can mean "no edge" in the dense matrix.
        A = SparseTensor(row=data.edge_index[0],
                         col=data.edge_index[1],
                         value=data.edge_attr[:, 0] + 1,
                         sparse_sizes=(N + 1, N + 1)).coalesce().to_dense()
        A[:, -1] = cfg.dataset.edge_encoder_num_types + 1
        A[-1, :] = cfg.dataset.edge_encoder_num_types + 1
        A[-1, -1] = 0
        A += torch.diag_embed(torch.ones([N + 1], dtype=torch.long) * (cfg.dataset.edge_encoder_num_types + 4))
        data.edge_attr = A.reshape(-1, 1).long()
        adj = torch.ones([N + 1, N + 1], dtype=torch.long)
        data.edge_index = dense_to_sparse(adj)[0]
        data.x = torch.cat([data.x[:, 0].unsqueeze(1),
                            torch.ones([1, 1], dtype=torch.long) * (cfg.dataset.node_encoder_num_types + 1)],
                           dim=0)
    elif format == 'OGB' and not cfg.train.pretrain.atom_bond_only:
        from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
        # Full OGB features: keep every column of x / edge_attr.
        data.x_original = data.x
        data.edge_index_original = data.edge_index
        data.edge_attr_original = data.edge_attr
        # +1 shifts every bond-feature value so 0 can mean "no edge".
        A = SparseTensor(row=data.edge_index[0],
                         col=data.edge_index[1],
                         value=data.edge_attr + 1,
                         sparse_sizes=(N + 1, N + 1)).coalesce().to_dense()
        # BUGFIX: the virtual-edge sentinel is an EDGE feature vector of width
        # edge_attr.shape[1], so it must be sized from the *bond* feature
        # dims. The previous code iterated get_atom_feature_dims() (more
        # entries than edge_attr has columns), indexing past the tensor.
        # Each sentinel is dim + 1: one past the largest shifted real value.
        edge_virtual = torch.zeros([1, data.edge_attr.shape[1]], dtype=torch.long)
        for i, dim in enumerate(get_bond_feature_dims()):
            edge_virtual[0, i] = dim + 1
        A[:, -1] = edge_virtual.repeat(N + 1, 1)
        A[-1, :] = edge_virtual.repeat(N + 1, 1)
        # Self-loops: same +4 offset scheme as the scalar branches, applied
        # per feature column.
        for j in range(N + 1):
            A[j, j] = edge_virtual + 4
        # Keep the (N+1, N+1, F) layout here (the scalar branches flatten to
        # (-1, 1) instead).
        data.edge_attr = A.reshape(N + 1, N + 1, -1).long()
        adj = torch.ones([N + 1, N + 1], dtype=torch.long)
        data.edge_index = dense_to_sparse(adj)[0]
        # Virtual node features: one past each atom feature's valid range
        # (x is NOT shifted in this branch, so valid values are 0..dim-1).
        x_virtual = torch.zeros([1, data.x.shape[1]], dtype=torch.long)
        for i, dim in enumerate(get_atom_feature_dims()):
            x_virtual[0, i] = dim
        data.x = torch.cat([data.x, x_virtual], dim=0)
    elif format == 'PyG-QM9':
        # QM9 handles the virtual node in its pre_transform; nothing to do.
        print('Already done in pretransform of QM9')
    else:
        raise NotImplementedError
    return data
# TODO: batchify large graphs, the code is as follows
# YourDataset(pre_transform=T.RootedEgoNets(hop=3))
# def random_mask(data, proportion_node=0.1, proportion_edge=0.1):
# N = data.num_nodes
# m = data.edge_index.shape[1]
# masked_node_idx = random.sample(range(N), int(N * proportion_node))
# masked_edge_idx = random.sample(range(m), int(m * proportion_edge))
# data.x_unmasked = data.x
# data.edge_attr_unmasked = data.edge_attr
# data.x[masked_node_idx] = torch.ones([int(N * proportion_node), 1], dtype=torch.long, device=data.x.device) * (cfg.dataset.node_encoder_num_types + 2)
# data.edge_attr[masked_edge_idx] = torch.ones([int(m * proportion_edge), 1], dtype=torch.long, device=data.x.device) * (cfg.dataset.edge_encoder_num_types + 1)
# return data, masked_node_idx, masked_edge_idx
def random_mask(data, proportion_node=0.1, proportion_edge=0.1):
    """Randomly mask node and edge features in place for masked pretraining.

    A sample of nodes/edges is overwritten with the dataset's mask tokens
    (``node_encoder_num_types + 2`` / ``edge_encoder_num_types + 1``), and
    pristine copies are kept on ``x_unmasked`` / ``edge_attr_unmasked``.

    Args:
        data: a PyG ``Data``-like graph; assumes ``x`` and ``edge_attr`` each
            have a single feature column (matches the ZINC path — TODO confirm
            for multi-column datasets).
        proportion_node: fraction of nodes to mask.
        proportion_edge: fraction of edges to mask.

    Returns:
        ``(data, masked_node_idx, masked_edge_idx)`` — the mutated graph plus
        the sampled node and edge index lists.
    """
    N = data.num_nodes
    m = data.edge_index.shape[1]
    n_masked = int(N * proportion_node)
    m_masked = int(m * proportion_edge)
    masked_node_idx = random.sample(range(N), n_masked)
    masked_edge_idx = random.sample(range(m), m_masked)
    # BUGFIX: clone before masking. The previous code stored bare aliases
    # (data.x_unmasked = data.x), so the in-place index assignments below
    # corrupted the "unmasked" backups as well.
    data.x_unmasked = data.x.clone()
    data.edge_attr_unmasked = data.edge_attr.clone()
    data.x[masked_node_idx] = torch.ones([n_masked, 1], dtype=data.x.dtype, device=data.x.device) * (cfg.dataset.node_encoder_num_types + 2)
    # BUGFIX: allocate the edge fill on edge_attr's device (was data.x.device).
    data.edge_attr[masked_edge_idx] = torch.ones([m_masked, 1], dtype=data.edge_attr.dtype, device=data.edge_attr.device) * (cfg.dataset.edge_encoder_num_types + 1)
    return data, masked_node_idx, masked_edge_idx