Skip to content

Commit ed7c279

Browse files
committed
added test cases
1 parent 3708788 commit ed7c279

13 files changed

Lines changed: 3556 additions & 4 deletions

fedgraph/trainer_class.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,9 @@ def __init__(
186186
# self.class_num = class_num
187187
self.args = args
188188
self.model = None
189+
self.optimizer = None
190+
self.global_node_num = None
191+
self.class_num = None
189192
self.feature_aggregation = None
190193
if self.args.method == "FedAvg":
191194
# print("Loading feature as the feature aggregation for fedavg method")
@@ -276,6 +279,9 @@ def update_params(self, params: tuple, current_global_epoch: int) -> None:
276279
The current global epoch number.
277280
"""
278281
# load global parameter from global server
282+
if self.model is None:
283+
return
284+
279285
self.model.to("cpu")
280286
for (
281287
p,
@@ -314,17 +320,22 @@ def get_local_feature_sum(self) -> torch.Tensor:
314320
normalized_sum : torch.Tensor
315321
The normalized sum of features of 1-hop neighbors for each node
316322
"""
323+
# Use global_node_num if available, otherwise infer from communicate_node_index
324+
global_node_num = getattr(self, 'global_node_num', None)
325+
if global_node_num is None:
326+
global_node_num = self.communicate_node_index.max().item() + 1
327+
317328
# Create a large matrix with known local node features
318329
new_feature_for_trainer = torch.zeros(
319-
self.global_node_num, self.features.shape[1]
330+
global_node_num, self.features.shape[1]
320331
).to(self.device)
321332
new_feature_for_trainer[self.local_node_index] = self.features
322333

323334
# Sum of features of all 1-hop nodes for each node
324335
one_hop_neighbor_feature_sum = get_1hop_feature_sum(
325336
new_feature_for_trainer, self.adj, self.device
326337
)
327-
if self.args.use_encryption:
338+
if hasattr(self.args, 'use_encryption') and self.args.use_encryption:
328339
print(
329340
f"Trainer {self.rank} - Original feature sum (first 10 and last 10 elements): "
330341
f"{one_hop_neighbor_feature_sum.flatten()[:10].tolist()} ... {one_hop_neighbor_feature_sum.flatten()[-10:].tolist()}"
@@ -522,6 +533,7 @@ def train(self, current_global_round: int) -> None:
522533
)
523534

524535
self.feature_aggregation = self.feature_aggregation.to(self.device)
536+
data = None
525537
if hasattr(self.args, "batch_size") and self.args.batch_size > 0:
526538
# batch preparation
527539
train_mask = torch.zeros(
@@ -541,6 +553,8 @@ def train(self, current_global_round: int) -> None:
541553
train_mask=train_mask,
542554
y=node_labels,
543555
)
556+
loss_train = 0.0
557+
acc_train = 0.0
544558
for iteration in range(self.local_step):
545559
self.model.train()
546560
if hasattr(self.args, "batch_size") and self.args.batch_size > 0:
@@ -621,13 +635,18 @@ def local_test(self) -> list:
621635
(list) : list
622636
A list containing the test loss and accuracy [local_test_loss, local_test_acc].
623637
"""
638+
if self.model is None or self.feature_aggregation is None:
639+
return [0.0, 0.0]
640+
624641
local_test_loss, local_test_acc = test(
625642
self.model,
626643
self.feature_aggregation,
627644
self.adj,
628645
self.test_labels,
629646
self.idx_test,
630647
)
648+
self.test_losses.append(local_test_loss)
649+
self.test_accs.append(local_test_acc)
631650
return [local_test_loss, local_test_acc]
632651

633652
def get_params(self) -> tuple:
@@ -639,8 +658,11 @@ def get_params(self) -> tuple:
639658
(tuple) : tuple
640659
A tuple containing the current parameters of the model.
641660
"""
642-
self.optimizer.zero_grad(set_to_none=True)
643-
return tuple(self.model.parameters())
661+
if self.optimizer is not None:
662+
self.optimizer.zero_grad(set_to_none=True)
663+
if self.model is not None:
664+
return tuple(self.model.parameters())
665+
return ()
644666

645667
def get_all_loss_accuray(self) -> list:
646668
"""

pytest.ini

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
[pytest]
2+
testpaths = tests
3+
python_files = test_*.py *_test.py
4+
python_classes = Test*
5+
python_functions = test_*
6+
addopts = -v --tb=short --strict-markers
7+
markers =
8+
unit: Unit tests
9+
integration: Integration tests
10+
slow: Slow running tests
11+
gpu: Tests requiring GPU
12+
ray: Tests requiring Ray cluster
13+
filterwarnings =
14+
ignore::DeprecationWarning
15+
ignore::UserWarning
16+
ignore::PendingDeprecationWarning

tests/__init__.py

Whitespace-only changes.

tests/conftest.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import pytest
2+
import torch
3+
import numpy as np
4+
import tempfile
5+
import os
6+
from unittest.mock import Mock, patch, MagicMock
7+
from torch_geometric.data import Data
8+
import shutil
9+
10+
@pytest.fixture
11+
def temp_dir():
12+
"""Create a temporary directory for test files."""
13+
temp_dir = tempfile.mkdtemp()
14+
yield temp_dir
15+
shutil.rmtree(temp_dir)
16+
17+
@pytest.fixture
18+
def sample_graph_data():
19+
"""Create sample graph data for testing."""
20+
num_nodes = 10
21+
num_features = 5
22+
num_classes = 3
23+
24+
# Create node features
25+
x = torch.randn(num_nodes, num_features)
26+
27+
# Create edges (simple chain graph)
28+
edge_index = torch.tensor([[i, i+1] for i in range(num_nodes-1)]).t().contiguous()
29+
30+
# Create labels
31+
y = torch.randint(0, num_classes, (num_nodes,))
32+
33+
data = Data(x=x, edge_index=edge_index, y=y)
34+
return data
35+
36+
@pytest.fixture
37+
def mock_args():
38+
"""Create mock arguments for testing."""
39+
args = Mock()
40+
args.num_clients = 5
41+
args.dataset = 'Cora'
42+
args.method = 'FedAvg'
43+
args.num_rounds = 10
44+
args.local_epochs = 5
45+
args.lr = 0.01
46+
args.hidden = 64
47+
args.dropout = 0.5
48+
args.device = 'cpu'
49+
args.seed = 42
50+
args.data_path = '/tmp/test_data'
51+
args.split_type = 'louvain'
52+
args.alpha = 0.5
53+
args.beta = 1.0
54+
args.num_workers = 1
55+
args.ray_init_address = None
56+
args.ray_dashboard_port = 8265
57+
args.he = False
58+
args.dp = False
59+
return args
60+
61+
@pytest.fixture
62+
def mock_ray_cluster():
63+
"""Mock Ray cluster for testing."""
64+
with patch('ray.init') as mock_init, \
65+
patch('ray.get') as mock_get, \
66+
patch('ray.put') as mock_put, \
67+
patch('ray.remote') as mock_remote:
68+
69+
mock_init.return_value = None
70+
mock_get.side_effect = lambda x: x
71+
mock_put.side_effect = lambda x: x
72+
mock_remote.side_effect = lambda x: x
73+
74+
yield {
75+
'init': mock_init,
76+
'get': mock_get,
77+
'put': mock_put,
78+
'remote': mock_remote
79+
}
80+
81+
@pytest.fixture
82+
def sample_dataset_splits():
83+
"""Create sample dataset splits for federated learning."""
84+
num_clients = 3
85+
num_nodes_per_client = 5
86+
87+
splits = {}
88+
for i in range(num_clients):
89+
client_data = {
90+
'train_mask': torch.zeros(num_nodes_per_client, dtype=torch.bool),
91+
'val_mask': torch.zeros(num_nodes_per_client, dtype=torch.bool),
92+
'test_mask': torch.zeros(num_nodes_per_client, dtype=torch.bool),
93+
'node_list': list(range(i * num_nodes_per_client, (i + 1) * num_nodes_per_client))
94+
}
95+
# Set some nodes for training
96+
client_data['train_mask'][:3] = True
97+
client_data['val_mask'][3:4] = True
98+
client_data['test_mask'][4:5] = True
99+
100+
splits[f'client_{i}'] = client_data
101+
102+
return splits
103+
104+
@pytest.fixture
105+
def mock_model():
106+
"""Create a mock GNN model for testing."""
107+
model = Mock()
108+
model.parameters.return_value = [torch.randn(10, 5, requires_grad=True)]
109+
model.state_dict.return_value = {'layer.weight': torch.randn(10, 5)}
110+
model.load_state_dict = Mock()
111+
model.train = Mock()
112+
model.eval = Mock()
113+
return model
114+
115+
@pytest.fixture
116+
def mock_optimizer():
117+
"""Create a mock optimizer for testing."""
118+
optimizer = Mock()
119+
optimizer.zero_grad = Mock()
120+
optimizer.step = Mock()
121+
optimizer.state_dict.return_value = {}
122+
optimizer.load_state_dict = Mock()
123+
return optimizer

tests/integration/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)