@@ -186,6 +186,9 @@ def __init__(
         # self.class_num = class_num
         self.args = args
         self.model = None
+        self.optimizer = None
+        self.global_node_num = None
+        self.class_num = None
         self.feature_aggregation = None
         if self.args.method == "FedAvg":
             # print("Loading feature as the feature aggregation for fedavg method")
@@ -276,6 +279,9 @@ def update_params(self, params: tuple, current_global_epoch: int) -> None:
             The current global epoch number.
         """
         # load global parameter from global server
+        if self.model is None:
+            return
+
         self.model.to("cpu")
         for (
             p,
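The added `None` guard above protects the parameter copy that follows the truncated `for (p, ...` loop. As a rough illustration only, the zip-and-assign pattern below is an assumption about the surrounding code, not copied from it:

    import torch

    model = torch.nn.Linear(3, 2)
    # Stand-in for parameters broadcast by the global server.
    params = tuple(torch.ones_like(p) for p in model.parameters())

    model.to("cpu")
    for p, mp in zip(params, model.parameters()):
        mp.data = p  # overwrite each local tensor with its global counterpart
    print(model.weight.data[0])  # tensor([1., 1., 1.])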
@@ -314,17 +320,22 @@ def get_local_feature_sum(self) -> torch.Tensor:
         normalized_sum : torch.Tensor
             The normalized sum of features of 1-hop neighbors for each node
         """
+        # Use global_node_num if available, otherwise infer from communicate_node_index
+        global_node_num = getattr(self, 'global_node_num', None)
+        if global_node_num is None:
+            global_node_num = self.communicate_node_index.max().item() + 1
+
         # Create a large matrix with known local node features
         new_feature_for_trainer = torch.zeros(
-            self.global_node_num, self.features.shape[1]
+            global_node_num, self.features.shape[1]
         ).to(self.device)
         new_feature_for_trainer[self.local_node_index] = self.features

         # Sum of features of all 1-hop nodes for each node
         one_hop_neighbor_feature_sum = get_1hop_feature_sum(
             new_feature_for_trainer, self.adj, self.device
         )
-        if self.args.use_encryption:
+        if hasattr(self.args, 'use_encryption') and self.args.use_encryption:
             print(
                 f"Trainer {self.rank} - Original feature sum (first 10 and last 10 elements): "
                 f"{one_hop_neighbor_feature_sum.flatten()[:10].tolist()} ... {one_hop_neighbor_feature_sum.flatten()[-10:].tolist()}"
@@ -522,6 +533,7 @@ def train(self, current_global_round: int) -> None:
         )

         self.feature_aggregation = self.feature_aggregation.to(self.device)
+        data = None
         if hasattr(self.args, "batch_size") and self.args.batch_size > 0:
             # batch preparation
             train_mask = torch.zeros(
@@ -541,6 +553,8 @@ def train(self, current_global_round: int) -> None:
                 train_mask=train_mask,
                 y=node_labels,
             )
+        loss_train = 0.0
+        acc_train = 0.0
         for iteration in range(self.local_step):
             self.model.train()
             if hasattr(self.args, "batch_size") and self.args.batch_size > 0:
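The `loss_train`/`acc_train` zero-initializations above matter when `self.local_step` is 0: the loop body then never binds those names, and later reads would raise `UnboundLocalError`. A self-contained reproduction of the pattern (simplified control flow, not the repository's code):

    def train_without_init(local_step):
        for _ in range(local_step):
            loss_train = 0.5  # stand-in for a real loss value
        return loss_train  # UnboundLocalError when local_step == 0

    def train_with_init(local_step):
        loss_train = 0.0  # safe default if the loop never runs
        for _ in range(local_step):
            loss_train = 0.5
        return loss_train

    print(train_with_init(0))  # 0.0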
@@ -621,13 +635,18 @@ def local_test(self) -> list:
         (list) : list
             A list containing the test loss and accuracy [local_test_loss, local_test_acc].
         """
+        if self.model is None or self.feature_aggregation is None:
+            return [0.0, 0.0]
+
         local_test_loss, local_test_acc = test(
             self.model,
             self.feature_aggregation,
             self.adj,
             self.test_labels,
             self.idx_test,
         )
+        self.test_losses.append(local_test_loss)
+        self.test_accs.append(local_test_acc)
         return [local_test_loss, local_test_acc]

     def get_params(self) -> tuple:
@@ -639,8 +658,11 @@ def get_params(self) -> tuple:
         (tuple) : tuple
             A tuple containing the current parameters of the model.
         """
-        self.optimizer.zero_grad(set_to_none=True)
-        return tuple(self.model.parameters())
+        if self.optimizer is not None:
+            self.optimizer.zero_grad(set_to_none=True)
+        if self.model is not None:
+            return tuple(self.model.parameters())
+        return ()

     def get_all_loss_accuray(self) -> list:
         """