-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_gc.py
More file actions
138 lines (127 loc) · 5.62 KB
/
eval_gc.py
File metadata and controls
138 lines (127 loc) · 5.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
import random
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
import torch
import torch.nn.functional as F
from torch.utils.data import random_split
from torch_geometric.utils import k_hop_subgraph
from torch_geometric.loader import DataLoader
from models import GCN, GIN, GAT
from configs import get_arguments
from load_datasets import get_gc_dataset
from explainers.evaluate import eval_top_edges_drop, eval_top_edges_keep
# Parse CLI arguments and load the graph-classification dataset.
args = get_arguments()
dataset_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data')
dataset_name = args.dataset
dataset = get_gc_dataset(dataset_path, dataset_name, self_loops=True)

# 60/20/20 train/test/eval split; the eval split takes the remainder so the
# three sizes always sum exactly to len(dataset).
num_train = int(.6 * len(dataset))
num_test = int(.2 * len(dataset))
num_eval = len(dataset) - num_train - num_test
# Renamed from `train`/`eval`/`test`: the original bound the eval split to
# the name `eval`, shadowing the builtin. Seeded generator keeps the split
# reproducible across runs.
train_set, eval_set, test_set = random_split(dataset, lengths=[num_train, num_eval, num_test],
                                             generator=torch.Generator().manual_seed(0))
train_batches = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
# Eval/test loaders use a single batch covering the whole split.
eval_batches = DataLoader(eval_set, batch_size=num_eval)
test_batches = DataLoader(test_set, batch_size=num_test)
# Instantiate the GNN architecture requested on the command line. All three
# models share the same core hyper-parameters; GCN and GAT additionally
# disable their internal self-loop insertion (the dataset already carries
# self-loops), and GAT uses 8 attention heads.
model_key = args.model.lower()
shared_kwargs = dict(in_channels=dataset.num_node_features,
                     hidden_channels=args.hidden_channels,
                     num_layers=args.num_layers,
                     out_channels=dataset.num_classes,
                     dropout=args.dropout,
                     readout=args.readout)
if model_key == 'gcn':
    gnn = GCN(add_self_loops=False, **shared_kwargs)
elif model_key == 'gin':
    gnn = GIN(**shared_kwargs)
elif model_key == 'gat':
    gnn = GAT(add_self_loops=False, heads=8, **shared_kwargs)
else:
    raise NotImplementedError('GNN not implemented!')
# Directory holding the pre-trained GNN checkpoints.
model_dir = './src'
if not os.path.exists(model_dir):
    os.mkdir(model_dir)
# Checkpoint file name encodes dataset, architecture, and depth,
# e.g. 'ba_2motifs_gcn_l3.pt'.
model_name = dataset_name + '_' + args.model.lower() + '_l' + str(args.num_layers)
# NOTE(review): torch.load unpickles arbitrary objects — only safe for
# locally-produced checkpoints.
gnn.load_state_dict(torch.load(os.path.join(model_dir, model_name + '.pt')))
gnn.eval()  # inference mode: disables dropout
# Single batch containing the entire test split (loader batch_size=num_test).
data = next(iter(test_batches))
device = torch.device('cpu')
gnn = gnn.to(device)
data = data.to(device)
out = gnn(data.x, data.edge_index, batch=data.batch)
prob = F.softmax(out, dim=-1)
pred = out.argmax(dim=-1)
correct = (pred == data.y.view(-1)).sum()
# Truncate accuracy to 4 decimal places (floor-divide by 1e-4, scale back).
# NOTE(review): assumes len(data) is the number of graphs in the batch —
# confirm against the torch_geometric Batch.__len__ semantics in use; if it
# returns the number of stored attributes instead, data.num_graphs is the
# correct denominator.
acc = torch.div(correct / len(data), 1e-4, rounding_mode='floor') * 1e-4
print(f'Accuracy: {acc * 100:.2f}')
# Explainer outputs for this model live under ./res/<model_name>/.
res_dir = os.path.join('./res', model_name)

# Draw a reproducible random ordering of graph indices to evaluate.
random.seed(2024)
graph_ids = list(range(len(dataset)))
random.shuffle(graph_ids)

# Clamp the requested number of candidate graphs to the dataset size;
# an unset --candidates means "use every graph".
candidates = args.candidates
if candidates is None or candidates > len(dataset):
    candidates = len(dataset)

# Accumulators: original prediction confidence, confidence after dropping /
# keeping the top-ranked edges, and per-graph explanation AUC (synthetic data).
old_pred, new_pred_drop, new_pred_keep, AUC = [], [], [], []
# Evaluate the stored explanation of each sampled graph against the model.
for index in tqdm(graph_ids[:candidates]):
    # Flow-based explainers persist a dict whose 'mask' entry is the edge
    # mask; all other explainers persist the edge-mask tensor directly.
    # The '_plus_' files are the fidelity+ variants, which some explainers
    # (gnn-lrp, deeplift, gradcam, pgmexplainer) do not produce.
    if args.explainer in ['revelio', 'gnn-lrp', 'flowx']:
        if args.fidelity_plus and args.explainer != 'gnn-lrp':
            flows = torch.load(os.path.join(res_dir, args.explainer + '_plus_' + str(index) + '.pt'),
                               map_location='cpu')
        else:
            flows = torch.load(os.path.join(res_dir, args.explainer + '_' + str(index) + '.pt'),
                               map_location='cpu')
        edge_mask = flows['mask']
    else:
        if args.fidelity_plus and args.explainer not in ['deeplift', 'gradcam', 'pgmexplainer']:
            edge_mask = torch.load(os.path.join(res_dir, args.explainer + '_plus_' + str(index) + '.pt'),
                                   map_location='cpu')
        else:
            edge_mask = torch.load(os.path.join(res_dir, args.explainer + '_' + str(index) + '.pt'),
                                   map_location='cpu')
    # Original (unperturbed) prediction for this single graph.
    data = dataset[index]
    out = gnn(data.x, data.edge_index).view(-1)
    prob = F.softmax(out, dim=-1)
    pred = out.argmax(dim=-1)
    # Model confidence after dropping / keeping the edges ranked highest by
    # the explanation mask, at 50%-90% sparsity thresholds.
    pred_hat_drop = eval_top_edges_drop(gnn, data.x, data.edge_index, edge_mask, torch.arange(edge_mask.shape[0]),
                                        pred, None, 0.5, 0.6, 0.7, 0.8, 0.9)
    new_pred_drop.append(pred_hat_drop)
    pred_hat_keep = eval_top_edges_keep(gnn, data.x, data.edge_index, edge_mask, torch.arange(edge_mask.shape[0]),
                                        pred, None, 0.5, 0.6, 0.7, 0.8, 0.9)
    new_pred_keep.append(pred_hat_keep)
    old_pred.append(prob[pred])
    # NaN masks are skipped only for the AUC computation below; the three
    # fidelity lists stay aligned because their appends already happened.
    if torch.isnan(edge_mask).sum() > 0:
        continue
    # for synthetic dataset: score the mask against the ground-truth motif,
    # only for graphs correctly predicted as class 1.
    if dataset_name.lower() == 'ba_2motifs' and pred == data.y == 1:
        # NOTE(review): assumes the dataset appends the num_nodes self-loop
        # edges at the end of edge_index, so they are excluded from scoring
        # — confirm against get_gc_dataset(self_loops=True).
        loop_start = data.num_edges - data.num_nodes
        ground_truth = dataset.gen_motif_edge_mask(data, args.num_layers)
        y_true, y_pred = ground_truth[:loop_start], edge_mask[:loop_start]
        # roc_auc_score is undefined when only one class is present.
        if y_true.sum() == y_true.shape[0]:
            continue
        auc = roc_auc_score(y_true, y_pred)
        AUC.append(auc)
# Stack the per-graph scalars into tensors for vectorized fidelity math.
old_pred = torch.tensor(old_pred, dtype=torch.float)
new_pred_drop = torch.tensor(new_pred_drop, dtype=torch.float, device=old_pred.device)
new_pred_keep = torch.tensor(new_pred_keep, dtype=torch.float, device=old_pred.device)

# Fidelity-: confidence lost when keeping only the top-ranked edges
# (lower is better — the explanation alone should suffice).
print('Fidelity-')
fidelity_pos = (old_pred.unsqueeze(1) - new_pred_keep).mean(dim=0)
print(fidelity_pos.tolist())

# Fidelity+: confidence lost when removing the top-ranked edges
# (higher is better — the explanation should be necessary).
print('Fidelity+')
fidelity_neg = (old_pred.unsqueeze(1) - new_pred_drop).mean(dim=0)
print(fidelity_neg.tolist())

# Explanation AUC exists only for the synthetic dataset with ground-truth
# motifs. Guard the mean: if every instance was skipped (NaN masks or
# all-ones ground truth), AUC is empty and the division would raise
# ZeroDivisionError.
if dataset_name.lower() == 'ba_2motifs':
    if AUC:
        print('AUC:', sum(AUC) / len(AUC), ', %d instances are considered.' % (len(AUC)))
    else:
        print('AUC: no valid instances to evaluate.')