diff --git a/README.md b/README.md index a940d34..caaebda 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ -# EdgeLake Project +# Federated Learning -- AnyLog EdgeLake + + +- pip install -r requirements.txt ### Software Requirements - Python 3.7+ installed. @@ -30,6 +33,7 @@ git clone https://github.com/EdgeLake/EdgeLake - **master.env**: Ensure this file is available for configuration. - **start-script.al**: This script is necessary for proper configuration. - **Example CURL commands file**: Keep this handy for testing and reference. +- **/blockchain/.env**: Used to specify system variables, file paths, etc. Update with values corresponding to your system. --- @@ -63,7 +67,113 @@ git clone https://github.com/EdgeLake/EdgeLake ### Step 5: Testing with CURL Commands ---- + +''' +CURL REQUEST FOR DEPLOYING CONTRACT-- General Form + +curl -X POST http://localhost:[AGGREGATOR_PORT]/init \ +-H "Content-Type: application/json" \ +-d '{ + "nodeUrls": [ + "http://localhost:[NODE_0_PORT]", + "http://localhost:[NODE_1_PORT]" + ], + "model_path": "[FILE_PATH_TO_PYTORCH_MODEL_SOURCE_CODE]", + "model_init_params": [OPTIONAL, IF YOUR PYTORCH MODEL HAS ANY], + "model_name": "[MODEL_NAME]", + "model_weights_path": "[WHERE YOU WANT MODEL WEIGHTS SAVED]", + "data_handler_path": "[FILE_PATH_TO_DATA_HANDLER_SOURCE_CODE]", + "data_config": {"data": ["[DATA_FILE_FOR_NODE_0]", + "[DATA_FILE_FOR_NODE_1]"]} +}' +''' + + +''' +CURL REQUEST FOR DEPLOYING CONTRACT-- custom dataset and model, 1 node + +curl -X POST http://localhost:8080/init \ +-H "Content-Type: application/json" \ +-d '{ + "nodeUrls": [ + "http://localhost:8081" + ], + "model_path": "C:\\Users\\nehab\\cse115d\\testmodel.py", + "model_init_params": { "module__input_dim": 14 }, + "model_name": "custom_test", + "model_weights_path": "C:\\Users\\nehab\\cse115d\\model_weights.pt", + "data_handler_path": "C:\\Users\\nehab\\cse115d_anylog_edgelake\\custom_data_handler.py", + "data_config": {"data": 
["C:\\Users\\nehab\\cse115d_anylog_edgelake\\heart_data\\party_data\\party_0.csv"]} + +}' +''' + +''' +CURL REQUEST FOR DEPLOYING CONTRACT-- custom dataset and model, 2 nodes + +curl -X POST http://localhost:8080/init \ +-H "Content-Type: application/json" \ +-d '{ + "nodeUrls": [ + "http://localhost:8081", + "http://localhost:8082" + ], + "model_path": "C:\\Users\\nehab\\cse115d\\testmodel.py", + "model_init_params": { "module__input_dim": 14 }, + "model_name": "custom_test", + "model_weights_path": "C:\\Users\\nehab\\cse115d\\model_weights.pt", + "data_handler_path": "C:\\Users\\nehab\\cse115d_anylog_edgelake\\custom_data_handler.py", + "data_config": {"data": ["C:\\Users\\nehab\\cse115d_anylog_edgelake\\heart_data\\party_data\\party_0.csv", + "C:\\Users\\nehab\\cse115d_anylog_edgelake\\heart_data\\party_data\\party_1.csv"]} +}' +''' + +''' +CURL REQUEST FOR DEPLOYING CONTRACT-- mnist built in dataset and model, 2 nodes + +curl -X POST http://localhost:8080/init \ +-H "Content-Type: application/json" \ +-d '{ + "nodeUrls": [ + "http://localhost:8081", + "http://localhost:8082" + ], + "model_path": "C:\\Users\\nehab\\cse115d\\Anylog-Edgelake-CSE115D\\federated-learning-lib-main\\examples\\iter_avg\\model_pytorch.py", + "model_name": "mnist_test", + "model_weights_path": "C:\\Users\\nehab\\cse115d\\model_weights.pt", + "data_handler_path": "C:\\Users\\nehab\\cse115d\\Anylog-Edgelake-CSE115D\\venv38\\Lib\\site-packages\\ibmfl\\util\\data_handlers\\mnist_pytorch_data_handler.py", + "data_config": { + "npz_file": [ + "C:\\Users\\nehab\\cse115d\\Anylog-Edgelake-CSE115D\\blockchain\\data\\mnist\\data_party0.npz", + "C:\\Users\\nehab\\cse115d\\Anylog-Edgelake-CSE115D\\blockchain\\data\\mnist\\data_party1.npz" + ] + } +}' +''' + +''' +CURL REQUEST FOR DEPLOYING CONTRACT-- mnist dataset and sql model, 2 nodes + +curl -X POST http://localhost:8080/init \ +-H "Content-Type: application/json" \ +-d '{ + "nodeUrls": [ + "http://localhost:8081", + "http://localhost:8082" + ], + 
"model_path": "C:\\Users\\nehab\\cse115d\\Anylog-Edgelake-CSE115D\\federated-learning-lib-main\\examples\\iter_avg\\model_pytorch.py", + "model_name": "mnist_test", + "model_weights_path": "C:\\Users\\nehab\\cse115d\\model_weights.pt", + "data_handler_path": "C:\\Users\\nehab\\cse115d\\Anylog-Edgelake-CSE115D\\blockchain\\custom_sql_datahandler.py", + "data_config": { + "npz_file": [ + "C:\\Users\\nehab\\cse115d\\Anylog-Edgelake-CSE115D\\blockchain\\data\\mnist\\data_party0.npz", + "C:\\Users\\nehab\\cse115d\\Anylog-Edgelake-CSE115D\\blockchain\\data\\mnist\\data_party1.npz" + ] + } +}' +''' + ## Automating CURL Commands @@ -90,6 +200,7 @@ for cmd in curl_commands: ### Instructions for Script + --- ## Troubleshooting @@ -102,4 +213,3 @@ for cmd in curl_commands: - Verify the EdgeLake service is running. - Check network connectivity and server status. ---- diff --git a/blockchain/.env b/blockchain/.env index d554853..e217c3d 100644 --- a/blockchain/.env +++ b/blockchain/.env @@ -7,6 +7,7 @@ DATABASE_URL=https://anylog-edgelake-fl-default-rtdb.firebaseio.com # Ethereum provider URL PROVIDER_URL=https://optimism-sepolia.infura.io/v3/6fce3361490c4187b59947005a07c3e7 + # Ethereum private key PRIVATE_KEY=f155acda1fc73fa6f50456545e3487b78fd517411708ffa1f67358c1d3d54977 @@ -27,3 +28,4 @@ EXTERNAL_IP=10.0.0.171 # LOCAL PSQL DB NAME PSQL_DB_NAME=mnist_fl + diff --git a/blockchain/aggregator.py b/blockchain/aggregator.py index 83595c8..8efa2ae 100644 --- a/blockchain/aggregator.py +++ b/blockchain/aggregator.py @@ -13,7 +13,7 @@ from ibmfl.util.data_handlers.mnist_pytorch_data_handler import MnistPytorchDataHandler -CONTRACT_ADDRESS = "0x4ae311B85B017bf7EAa7a96D3109f58795F5F4BF" +import time load_dotenv() @@ -51,6 +51,7 @@ def start_round(self, initParamsLink, roundNumber): external_ip = os.getenv("EXTERNAL_IP") url = f'http://{external_ip}:32049' + # in khaled's: for node_num in range(1, int(minParams) + 1) headers = { 'User-Agent': 'AnyLog/1.23', 'Content-Type': 
'text/plain', @@ -58,14 +59,33 @@ def start_round(self, initParamsLink, roundNumber): } # Format data exactly like the example curl command but with your values - # NOTE: ask why are we adding the node num from agg - data = f'''''' - print(f"Training initialized with {roundNumber} rounds") - + # retries = 0; + # max_retries = 5; + # while retries < max_retries: + # response = requests.post(url, headers=headers, data=data) + # if response.status_code == 200: + # print(f"Aggregator has submitted parameters for round {roundNumber} to the blockchain.") + # return { + # 'status': 'success', + # 'message': 'Aggregator model parameters added successfully' + # } + # else: + # print(f"Failed to add aggregator params to blockchain. Response: {response}. Retrying ({retries + 1}/{max_retries})...") + # retries += 1; + # time.sleep(15); + + # return { + # 'status': 'error', + # 'message': 'aggregator was unable to add to blockchain' + # } + response = requests.post(url, headers=headers, data=data) + print(response.status_code) if response.status_code == 200: return { 'status': 'success', @@ -145,6 +165,7 @@ def decode_params(self, encoded_model_update): model_weights = pickle.loads(serialized_data) return model_weights - def inference(self, data): - results = self.fusion_model.fl_model.predict(data) + + def inference(self, model, data): + results = model.evaluate(data) return results diff --git a/blockchain/aggregator_server.py b/blockchain/aggregator_server.py index 27d6877..ded2bff 100644 --- a/blockchain/aggregator_server.py +++ b/blockchain/aggregator_server.py @@ -8,6 +8,15 @@ import requests import os +import torch + +import firebase_admin +from firebase_admin import credentials, db + +import base64 + +from ibmfl.model.pytorch_fl_model import PytorchFLModel + app = Flask(__name__) load_dotenv() @@ -19,19 +28,89 @@ aggregator = Aggregator(PROVIDER_URL, PRIVATE_KEY) ''' -CURL REQUEST FOR DEPLOYING CONTRACT +CURL REQUEST FOR DEPLOYING CONTRACT-- custom dataset and model, 1 
node + +curl -X POST http://localhost:8080/init \ +-H "Content-Type: application/json" \ +-d '{ + "nodeUrls": [ + "http://localhost:8081" + ], + "model_path": "C:\\Users\\nehab\\cse115d\\testmodel.py", + "model_init_params": { "module__input_dim": 14 }, + "model_name": "custom_test", + "model_weights_path": "C:\\Users\\nehab\\cse115d\\model_weights.pt", + "data_handler_path": "C:\\Users\\nehab\\cse115d_anylog_edgelake\\custom_data_handler.py", + "data_config": {"data": ["C:\\Users\\nehab\\cse115d_anylog_edgelake\\heart_data\\party_data\\party_0.csv"]} + +}' +''' + +''' +CURL REQUEST FOR DEPLOYING CONTRACT-- custom dataset and model, 2 nodes curl -X POST http://localhost:8080/init \ -H "Content-Type: application/json" \ -d '{ "nodeUrls": [ - "http://localhost:8081", + "http://localhost:8081", "http://localhost:8082" ], - "model_def": 1 + "model_path": "C:\\Users\\nehab\\cse115d\\testmodel.py", + "model_init_params": { "module__input_dim": 14 }, + "model_name": "custom_test", + "model_weights_path": "C:\\Users\\nehab\\cse115d\\model_weights.pt", + "data_handler_path": "C:\\Users\\nehab\\cse115d_anylog_edgelake\\custom_data_handler.py", + "data_config": {"data": ["C:\\Users\\nehab\\cse115d_anylog_edgelake\\heart_data\\party_data\\party_0.csv", + "C:\\Users\\nehab\\cse115d_anylog_edgelake\\heart_data\\party_data\\party_1.csv"]} }' ''' +''' +CURL REQUEST FOR DEPLOYING CONTRACT-- mnist built in dataset and model, 2 nodes + +curl -X POST http://localhost:8080/init \ +-H "Content-Type: application/json" \ +-d '{ + "nodeUrls": [ + "http://localhost:8081", + "http://localhost:8082" + ], + "model_path": "C:\\Users\\nehab\\cse115d\\Anylog-Edgelake-CSE115D\\federated-learning-lib-main\\examples\\iter_avg\\model_pytorch.py", + "model_name": "mnist_test", + "model_weights_path": "C:\\Users\\nehab\\cse115d\\model_weights.pt", + "data_handler_path": 
"C:\\Users\\nehab\\cse115d\\Anylog-Edgelake-CSE115D\\venv38\\Lib\\site-packages\\ibmfl\\util\\data_handlers\\mnist_pytorch_data_handler.py", + "data_config": { + "npz_file": [ + "C:\\Users\\nehab\\cse115d\\Anylog-Edgelake-CSE115D\\blockchain\\data\\mnist\\data_party0.npz", + "C:\\Users\\nehab\\cse115d\\Anylog-Edgelake-CSE115D\\blockchain\\data\\mnist\\data_party1.npz" + ] + } +}' +''' + +''' +CURL REQUEST FOR DEPLOYING CONTRACT-- mnist dataset and sql model, 2 nodes + +curl -X POST http://localhost:8080/init \ +-H "Content-Type: application/json" \ +-d '{ + "nodeUrls": [ + "http://localhost:8081", + "http://localhost:8082" + ], + "model_path": "C:\\Users\\nehab\\cse115d\\Anylog-Edgelake-CSE115D\\federated-learning-lib-main\\examples\\iter_avg\\model_pytorch.py", + "model_name": "mnist_test", + "model_weights_path": "C:\\Users\\nehab\\cse115d\\model_weights.pt", + "data_handler_path": "C:\\Users\\nehab\\cse115d\\Anylog-Edgelake-CSE115D\\blockchain\\custom_sql_datahandler.py", + "data_config": { + "npz_file": [ + "C:\\Users\\nehab\\cse115d\\Anylog-Edgelake-CSE115D\\blockchain\\data\\mnist\\data_party0.npz", + "C:\\Users\\nehab\\cse115d\\Anylog-Edgelake-CSE115D\\blockchain\\data\\mnist\\data_party1.npz" + ] + } +}' +''' @app.route('/init', methods=['POST']) def deploy_contract(): @@ -39,17 +118,151 @@ def deploy_contract(): try: data = request.json node_urls = data.get('nodeUrls', []) - model_def = data.get('model_def', 1) + + model_path = data.get('model_path', os.getenv('MODEL_PYTHON')) + model_init_params = data.get('model_init_params', None) + model_name = data.get('model_name', 'model') + model_weights_path = data.get('model_weights_path') + + # upload model to firebase + firebase_model_path = f"/models/{model_name}" + pytorch_upload(model_path, model_init_params, model_name, firebase_model_path, model_weights_path) + + + data_handler_path = data.get('data_handler_path') + data_config = data.get('data_config') + + # upload datahandler to firebase + 
firebase_datahandler_path = f"/datahandlers/datahandler" + datahandler_upload(data_handler_path, data_config, firebase_datahandler_path) # Initialize the nodes and send the contract address - initialize_nodes(model_def, node_urls) + initialize_nodes(firebase_model_path, firebase_datahandler_path, node_urls) except Exception as e: return jsonify({'status': 'error', 'message': str(e)}), 500 - return f"Initialized nodes with model definition: {model_def}", 200 - - -def initialize_nodes(model_def, node_urls): + return f"Initialized nodes with model definition: {model_path}", 200 + +# creates and uploads a model to firebase for the nodes to download +def pytorch_upload(model_file_path, model_init_params, model_name, firebase_model_path, model_weights_path): + + # Read the model class source code + with open(model_file_path, "r") as f: + model_source_code = f.read() + + # Dynamically load the model class + namespace = {} + exec(model_source_code, namespace) + + # Identify the model class dynamically-- this searches for all classes that are a subset of nn.Module + model_class = None + for obj_name, obj in namespace.items(): + if isinstance(obj, type) and issubclass(obj, torch.nn.Module) and obj != torch.nn.Module: + model_class = obj + break + + if model_class is None: + print("No PyTorch model class found in the specified file.") + + print("Searching for nn.Sequential object...") + + # set up fields necessary for get_model_config + + # temporary folder for model serialization + folder_configs = os.path.join(os.getcwd(), "model"); + model_weights_path = os.path.join(folder_configs, "pytorch_sequence.pt"); + dataset = None; # this isn't even used in the model so let's skip for now + is_agg = False; # this is being uploaded for nodes to use, we want an actual model + party_id = None; # also isn't even used + + get_model_config = namespace['get_model_config'] + model_config = get_model_config(folder_configs, dataset, is_agg, party_id) + + if model_config is None or "spec" 
not in model_config: + raise ValueError("Failed to retrieve a valid model configuration.") + + #print("Model Config: ", model_config) + + # build model_specs + model_specs = model_config['spec'] + print("model specs: ", model_specs) + + # Initialize PytorchFLModel + fl_model = PytorchFLModel( + model_name="Pytorch_NN", + model_spec=model_specs + ) + + + else: + print("Identified model class:", model_class) + + # Initialize PytorchFLModel + fl_model = PytorchFLModel( + model_name="Pytorch_NN", + pytorch_module=model_class, + module_init_params=model_init_params, + ) + + # Save model weights + model_weights_path = os.path.join(os.getcwd(), "model\\pytorch_sequence.pt"); + fl_model.save_model(filename=model_weights_path) + + model_specs = None + + + # Encode the source code and model weights + encoded_source_code = encode_to_base64(model_source_code) + with open(model_weights_path, "rb") as f: + encoded_weights = encode_to_base64(f.read()) + + # define model info to upload + model_data = { + "source_code": encoded_source_code, + "weights": encoded_weights, + "init_params": model_init_params, + "model_spec": model_specs + } + + firebase_model_ref = db.reference(firebase_model_path) + firebase_model_ref.set(model_data) + print(f"PytorchFLModel uploaded to Firebase Realtime Database at {firebase_model_path}.") + +# creates and uploads a datahandler for nodes to download +def datahandler_upload(datahandler_file_path, data_config, firebase_datahandler_path): + + # read source code + with open(datahandler_file_path, "r") as f: + datahandler_source_code = f.read() + + # Encode the source code and configuration + encoded_source_code = encode_to_base64(datahandler_source_code) + encoded_data_config = encode_to_base64(str(data_config)) + + # Prepare the data to upload + datahandler_data = { + "source_code": encoded_source_code, + "data_config": encoded_data_config, + } + + # upload to firebase + datahandler_firebase_ref = db.reference(firebase_datahandler_path) + 
datahandler_firebase_ref.set(datahandler_data) + print(f"DataHandler uploaded to Firebase Realtime Database at {firebase_datahandler_path}") + + +def encode_to_base64(data): + """Encode binary or text data to Base64.""" + if isinstance(data, bytes): + return base64.b64encode(data).decode("utf-8") + return base64.b64encode(data.encode("utf-8")).decode("utf-8") + +def decode_from_base64(data): + """Decode Base64 data to binary or text.""" + return base64.b64decode(data) + + +def initialize_nodes(firebase_model_path, firebase_datahandler_path, node_urls): """Send the deployed contract address to multiple node servers.""" for urlCount in range(len(node_urls)): try: @@ -57,8 +270,9 @@ def initialize_nodes(model_def, node_urls): print(f"Sending contract address to node at {url}") response = requests.post(f'{url}/init-node', json={ - 'replica_name': f"node{urlCount+1}", - 'model_def': model_def + 'replica_name': f"node{urlCount}", + 'firebase_model_path': firebase_model_path, + 'firebase_datahandler_path': firebase_datahandler_path }) # TODO: figure out how to handle response @@ -80,7 +294,7 @@ def initialize_nodes(model_def, node_urls): -H "Content-Type: application/json" \ -d '{ "totalRounds": 5, - "minParams": 1 + "minParams": 2 }' ''' @@ -88,6 +302,7 @@ def initialize_nodes(model_def, node_urls): @app.route('/start-training', methods=['POST']) async def init_training(): """Start the training process by setting the number of rounds.""" + print('entered start_training endpoint') try: data = request.json num_rounds = data.get('totalRounds', 1) @@ -103,7 +318,7 @@ async def init_training(): for r in range(1, num_rounds + 1): print(f"Starting round {r}") aggregator.start_round(initialParams, r) - # print("Sent initial parameters to nodes") + print("Finished start_round function") # Listen for updates from nodes newAggregatorParams = await listen_for_update_agg(min_params, r) # print("Received aggregated parameters") @@ -148,10 +363,10 @@ async def 
listen_for_update_agg(min_params, roundNumber): while True: try: - # Check parameter count + # Check parameter count for node responses count_response = requests.get(url, headers={ 'User-Agent': 'AnyLog/1.23', - "command": f"blockchain get a{roundNumber} count" + "command": f"blockchain get r{roundNumber} count" }) if count_response.status_code == 200: @@ -162,17 +377,19 @@ async def listen_for_update_agg(min_params, roundNumber): if count >= min_params: params_response = requests.get(url, headers={ 'User-Agent': 'AnyLog/1.23', - "command": f"blockchain get a{roundNumber}" + "command": f"blockchain get r{roundNumber}" }) if params_response.status_code == 200: result = params_response.json() + # print(f"blockchain get r{roundNumber} returns {result}") + if result and len(result) > 0: # Extract all trained_params into a list node_params_links = [ - item[f'a{roundNumber}']['trained_params'] + item[f'r{roundNumber}']['trained_params'] for item in result - if f'a{roundNumber}' in item + if f'r{roundNumber}' in item ] # print(f"Collected trained_params links: {node_params_links}") # Debugging line @@ -188,6 +405,28 @@ async def listen_for_update_agg(min_params, roundNumber): await asyncio.sleep(2) +# added inference endpoint +@app.route('/inference', methods=['POST']) +def inference(): + """Inference on current model w/ data passed in.""" + try: + data = request.json + test_data = data.get('data', {}) + + results = aggregator.inference(aggregator.fusion_model.fl_model, test_data) + + response = { + 'status': 'success', + 'message': 'Inference completed successfully', + 'model_accuracy': results['acc'] * 100, + 'classification_report': results['classificatio_report'] + } + + return jsonify(response) + + except Exception as e: + return jsonify({'status': 'error', 'message': str(e)}), 500 + if __name__ == '__main__': # Add argument parsing to make the port configurable diff --git a/blockchain/node.py b/blockchain/node.py index ff26f04..672c342 100644 --- 
a/blockchain/node.py +++ b/blockchain/node.py @@ -11,15 +11,21 @@ from ibmfl.model.pytorch_fl_model import PytorchFLModel from custom_data_handler import CustomMnistPytorchDataHandler import requests +import torch +import time # import pathlib from dotenv import load_dotenv load_dotenv() +def decode_from_base64(data): + """Decode Base64 data to binary or text.""" + return base64.b64decode(data) + class Node: - def __init__(self, model_def, replica_name): + def __init__(self, firebase_model_path, firebase_datahandler_path, replica_name): print("Node initializing") @@ -39,31 +45,127 @@ def __init__(self, model_def, replica_name): self.currentRound = 1 - current_dir = os.path.dirname(os.path.abspath(__file__)) - - # USE MNIST DATASET FOR TESTING THIS FUNCTIONALITY - data_path = os.path.join(current_dir, "data", "mnist", "data_party0.npz") - data_config = { - "npz_file": str(data_path) - } + # download model from firebase + self.fl_model = self.load_firebase_model(firebase_model_path); + # download datahandler from firebase + self.data_handler = self.load_firebase_datahandler(firebase_datahandler_path); + + # create the local training handler + self.local_training_handler = LocalTrainingHandler(fl_model=self.fl_model, data_handler=self.data_handler) + + + + + def load_firebase_model(self, firebase_model_path): + + # get model_data from firebase + firebase_model_ref = db.reference(firebase_model_path) + model_data = firebase_model_ref.get() + if model_data is None: + print("Error: No model data found in Firebase.") + return + + print('Downloaded model_data from Firebase') + + # derive fields, then decode source code and weights + model_source_code = decode_from_base64(model_data["source_code"]).decode("utf-8") + model_weights = decode_from_base64(model_data["weights"]) + init_params = model_data.get("init_params", None) + model_specs = model_data.get("model_spec", None) + + # Save the downloaded weights to a file-- necessary to load the model later + 
downloaded_weights_path = "downloaded_model_weights.pt" + with open(downloaded_weights_path, "wb") as f: + f.write(model_weights) + + print('Saved weights to local file') + + # Dynamically recreate the model class + downloaded_namespace = {} + exec(model_source_code, downloaded_namespace) + + # Identify the model class in the downloaded namespace + downloaded_model_class = None + for obj_name, obj in downloaded_namespace.items(): + if isinstance(obj, type) and issubclass(obj, torch.nn.Module) and obj != torch.nn.Module: + downloaded_model_class = obj + break + + if downloaded_model_class is None: + print("No PyTorch model class found in the downloaded source code.") + + print("Creating nn.Sequential version...") + + # set up fields necessary for get_model_config + + fl_model = PytorchFLModel( + model_name="Pytorch_NN", + model_spec=model_specs + ) + else: + print("Recreated model class:", downloaded_model_class) + + # Reinitialize the PytorchFLModel and load the weights + fl_model = PytorchFLModel( + model_name="Pytorch_NN", + pytorch_module=downloaded_model_class, + module_init_params=init_params, + ) + + fl_model.load_model( + pytorch_module=downloaded_model_class, + model_filename=downloaded_weights_path, + module_init_params=init_params, + ) + + print("PytorchFLModel successfully reconstructed and loaded.") + return fl_model + + + def load_firebase_datahandler(self, firebase_datahandler_path): + + # get datahandler data from firebase + firebase_datahandler_ref = db.reference(firebase_datahandler_path) + datahandler_data = firebase_datahandler_ref.get() + if datahandler_data is None: + print("Error: No DataHandler data found in Firebase.") + return + + print('Datahandler data acquired from firebase') + + # decode the source code and configuration + datahandler_source_code = decode_from_base64(datahandler_data["source_code"]).decode("utf-8") + data_config = eval(decode_from_base64(datahandler_data["data_config"]).decode("utf-8")) + + # dynamically recreate the 
DataHandler class + namespace = {} + exec(datahandler_source_code, namespace) + + # find the datahandler class, which must be a subclass of DataHandler + datahandler_class = None + for obj_name, obj in namespace.items(): + if isinstance(obj, type) and issubclass(obj, namespace.get('DataHandler', object)) and obj != namespace['DataHandler']: + datahandler_class = obj + break + + if datahandler_class is None: + raise ValueError("No DataHandler subclass found in the downloaded source code.") + else: + print("Recreated DataHandler class:", datahandler_class) - # model_def == 1: PytorchFLModel - if model_def == 1: - model_path = os.path.join(current_dir, "configs", "node", "pytorch", "pytorch_sequence.pt") + # data config contains the data paths for all nodes + # usually, i think it'd make more sense for the nodes to set this in their local .env files + # but for now, data config contains all paths-- we can access this specific node's from the replica name + replicaNumber = int(self.replicaName[-1]) + key = next(iter(data_config)) + personal_data_config = {key: data_config[key][replicaNumber]} - model_spec = { - "loss_criterion": "nn.NLLLoss", - "model_definition": str(model_path), - "model_name": "pytorch-nn", - "optimizer": "optim.Adadelta" - } - - fl_model = PytorchFLModel(model_name="pytorch-nn", model_spec=model_spec) - data_handler = CustomMnistPytorchDataHandler(self.replicaName) - self.local_training_handler = LocalTrainingHandler(fl_model=fl_model, data_handler=data_handler) - # add more model defs in elifs below - # model_def == 2: Sklearn and so on + # initialize the datahandler with the configuration + data_handler = datahandler_class(data_config=personal_data_config) + print("DataHandler successfully reconstructed and initialized.") + return data_handler + ''' add_data_batch(data) - Adds passed in data to local storage @@ -80,8 +182,6 @@ def add_data_batch(self, data): ''' def add_node_params(self, round_number, newly_trained_params_db_link): - print("in 
add_node_params") - try: external_ip = os.getenv("EXTERNAL_IP") url = f'http://{external_ip}:32049' @@ -92,14 +192,36 @@ def add_node_params(self, round_number, newly_trained_params_db_link): 'command': 'blockchain insert where policy = !my_policy and local = true and blockchain = optimism' } - data = f'''''' + }} }}>''' + + # retries = 0; + # max_retries = 5; + # while retries < max_retries: + # response = requests.post(url, headers=headers, data=data) + # if response.status_code == 200: + # print(f"{self.replicaName} has submitted results for round {round_number}") + # return { + # 'status': 'success', + # 'message': 'node model parameters added successfully' + # } + # else: + # print(f"Failed to add node {self.replicaName} params to blockchain. Response: {response}. Retrying ({retries + 1}/{max_retries})...") + # retries += 1; + # time.sleep(15); + + # return { + # 'status': 'error', + # 'message': 'node was unable to add to blockchain' + # } - # print(f"Submitting results for round {round_number}") response = requests.post(url, headers=headers, data=data) - print(f"Results submitted for round {round_number} to {self.replicaName}") + print("response after addding node data to blockchain ", response); + + print(f"{self.replicaName} has submitted results for round {round_number}") return { 'status': 'success', @@ -119,13 +241,18 @@ def add_node_params(self, round_number, newly_trained_params_db_link): ''' def train_model_params(self, aggregator_model_params_db_link, round_number): - print(f"in train_model_params for round {round_number}") + print(f"Training for round {round_number}") + + weights = '' # First round initialization if round_number == 1: + #print('initializing weights, round1') weights = self.local_training_handler.fl_model.get_model_update() + #print("round 1 weights", weights) else: try: + #print('round1+, getting weights from aggregator') # Extract the key from the URL model_updates_key = 
aggregator_model_params_db_link.split('/')[-1].replace('.json', '') @@ -167,6 +294,8 @@ def train_model_params(self, aggregator_model_params_db_link, round_number): 'model_update': encoded_params }) + print('Pushed weights to Firebase') + return f"{self.database_url}/node_model_updates/{data_pushed.key}.json" def encode_model(self, model_update): @@ -181,6 +310,12 @@ def decode_params(self, encoded_model_update): model_weights = pickle.loads(serialized_data) return model_weights + + # modified to get test data from datahandler def inference(self, data): - results = self.local_training_handler.fl_model.predict(data) - return results \ No newline at end of file + data1 = self.data_handler.get_data() + print("got data from inference handler") + data_test = data1[1]; + results = self.fl_model.evaluate(data_test) + print("results ", results); + return results diff --git a/blockchain/node_server.py b/blockchain/node_server.py index e6e98fd..4a0d125 100644 --- a/blockchain/node_server.py +++ b/blockchain/node_server.py @@ -64,11 +64,10 @@ def init_node(): """Receive the contract address from the aggregator server.""" global node_instance, listener_thread, stop_listening_thread try: - model_def = request.json.get('model_def', 1) replica_name = request.json.get('replica_name') + firebase_model_path = request.json.get('firebase_model_path') + firebase_datahandler_path = request.json.get('firebase_datahandler_path') - if not model_def: - return jsonify({'status': 'error', 'message': 'No config provided'}), 400 if listener_thread and listener_thread.is_alive(): stop_listening_thread = True listener_thread.join(timeout=1) @@ -77,7 +76,7 @@ def init_node(): stop_listening_thread = False # Instantiate the Node class - node_instance = Node(model_def, replica_name) + node_instance = Node(firebase_model_path, firebase_datahandler_path, replica_name) node_instance.currentRound = 1 print(f"{replica_name} successfully initialized") @@ -118,33 +117,35 @@ def 
listen_for_start_round(nodeInstance, stop_event): url = f'{external_ip}:32049' # next_round = nodeInstance.currentRound + 1 - print(f"listening for start round {nodeInstance.currentRound}") + print(f"Listening for start of round {nodeInstance.currentRound}") headers = { 'User-Agent': 'AnyLog/1.23', - 'command': f'blockchain get r{nodeInstance.currentRound}' + 'command': f'blockchain get a{nodeInstance.currentRound}' } response = requests.get(f'http://{url}', headers=headers) + # check if aggregator's params have been posted if response.status_code == 200: data = response.json() - # print(f"Response Data: {data}") # Debugging line round_data = None for item in data: # Check if the key exists in the current dictionary - if f'r{nodeInstance.currentRound}' in item: - round_data = item[f'r{nodeInstance.currentRound}'] + if f'a{nodeInstance.currentRound}' in item: + round_data = item[f'a{nodeInstance.currentRound}'] break # Stop searching once the current round's data is found if round_data: print(f"Round Data: {round_data}") # Debugging line paramsLink = round_data.get('initParams', '') modelUpdate = nodeInstance.train_model_params(paramsLink, nodeInstance.currentRound) + # print(modelUpdate); nodeInstance.add_node_params(nodeInstance.currentRound, modelUpdate) nodeInstance.currentRound += 1 - # else: # Debugging line - # print(f"No data found for round r{nodeInstance.currentRound}") + + else: + print(f"No aggregator parameters found for round {nodeInstance.currentRound}") time.sleep(2) # Poll every 2 seconds @@ -152,6 +153,34 @@ def listen_for_start_round(nodeInstance, stop_event): print(f"Error in listener thread: {str(e)}") time.sleep(2) +#inference untested +''' +curl -X POST http://localhost:8082/inference \ +-H "Content-Type: application/json" \ +-d '{ +}' +''' +@app.route('/inference', methods=['POST']) +def inference(): + """Inference on current model w/ data passed in.""" + try: + data = request.json + test_data = data.get('data', {}) + + results = 
node_instance.inference(test_data) + + print(results) + + response = { + 'status': 'success', + 'message': 'Inference completed successfully', + 'model_accuracy': results['accuracy_score'] * 100, + } + + return jsonify(response) + + except Exception as e: + return jsonify({'status': 'error', 'message': str(e)}), 500 @app.route('/inference', methods=['POST']) def inference():