1+ import torch
2+ import torch .nn as nn
3+ import tensorflow as tf
4+ import numpy as np
5+ import networkx as nx
6+
7+ # Simple feature extraction for graph classification
8+ def extract_graph_features (edges ):
9+ """Extract basic graph features from edge list"""
10+ if not edges :
11+ return np .zeros (64 ) # Return zero vector if no edges
12+
13+ # Create graph
14+ G = nx .Graph ()
15+ for edge in edges :
16+ if len (edge ) >= 2 :
17+ G .add_edge (str (edge [0 ]), str (edge [1 ]))
18+
19+ if G .number_of_nodes () == 0 :
20+ return np .zeros (64 )
21+
22+ # Basic graph features
23+ features = []
24+
25+ # Node count
26+ features .append (G .number_of_nodes ())
27+ # Edge count
28+ features .append (G .number_of_edges ())
29+ # Density
30+ features .append (nx .density (G ))
31+ # Average clustering
32+ features .append (nx .average_clustering (G ))
33+ # Number of connected components
34+ features .append (nx .number_connected_components (G ))
35+
36+ # Degree statistics
37+ degrees = [d for n , d in G .degree ()]
38+ if degrees :
39+ features .extend ([
40+ np .mean (degrees ),
41+ np .std (degrees ),
42+ np .max (degrees ),
43+ np .min (degrees )
44+ ])
45+ else :
46+ features .extend ([0 , 0 , 0 , 0 ])
47+
48+ # Pad or truncate to 64 features
49+ features = features [:64 ]
50+ while len (features ) < 64 :
51+ features .append (0.0 )
52+
53+ return np .array (features , dtype = np .float32 )
54+
55+ def convert_pth_to_h5 (pth_path , h5_path ):
56+ # Load PyTorch model state dict
57+ state_dict = torch .load (pth_path , map_location = 'cpu' )
58+
59+ # Create TensorFlow model that mimics the GCN structure
60+ # Input: 64-dimensional graph features
61+ tf_model = tf .keras .Sequential ([
62+ tf .keras .layers .Dense (64 , activation = 'relu' , input_shape = (64 ,), name = 'conv1' ),
63+ tf .keras .layers .Dense (64 , activation = 'relu' , name = 'conv2' ),
64+ tf .keras .layers .Dense (64 , activation = 'relu' , name = 'conv3' ),
65+ tf .keras .layers .Dense (3 , activation = 'softmax' , name = 'classifier' )
66+ ])
67+
68+ # Build the model
69+ tf_model .build ((None , 64 ))
70+
71+ # Extract and set weights from PyTorch model
72+ try :
73+ # Set conv1 weights (assuming it's the first GCN layer)
74+ conv1_weight = state_dict ['conv1.lin.weight' ].numpy ().T # Transpose for TF
75+ conv1_bias = state_dict ['conv1.bias' ].numpy ()
76+ tf_model .layers [0 ].set_weights ([conv1_weight , conv1_bias ])
77+
78+ # Set conv2 weights
79+ conv2_weight = state_dict ['conv2.lin.weight' ].numpy ().T
80+ conv2_bias = state_dict ['conv2.bias' ].numpy ()
81+ tf_model .layers [1 ].set_weights ([conv2_weight , conv2_bias ])
82+
83+ # Set conv3 weights
84+ conv3_weight = state_dict ['conv3.lin.weight' ].numpy ().T
85+ conv3_bias = state_dict ['conv3.bias' ].numpy ()
86+ tf_model .layers [2 ].set_weights ([conv3_weight , conv3_bias ])
87+
88+ # Set final linear layer weights
89+ lin_weight = state_dict ['lin.weight' ].numpy ().T
90+ lin_bias = state_dict ['lin.bias' ].numpy ()
91+ tf_model .layers [3 ].set_weights ([lin_weight , lin_bias ])
92+
93+ print ("Successfully transferred weights from PyTorch to TensorFlow" )
94+
95+ except Exception as e :
96+ print (f"Warning: Could not transfer all weights: { e } " )
97+ print ("Using randomly initialized weights" )
98+
99+ # Save TensorFlow model
100+ tf_model .save (h5_path )
101+ print (f"Model converted and saved to { h5_path } " )
102+
103+ if __name__ == "__main__" :
104+ pth_file = 'graph_classifier_model.pth'
105+ print (f"Converting { pth_file } to TensorFlow format..." )
106+ convert_pth_to_h5 (pth_file , 'model.h5' )
107+ print ("Conversion complete! You can now start the backend." )
0 commit comments