66from include .TensorModel import TensorModel
77from include .ModelProfiler import ModelProfiler
88
9- NUM_EPOCHS = 25
9+ NUM_EPOCHS = 5
1010BATCH_FITTING = 128
1111BATCH_PROFILING = [32 , 64 , 128 ]
1212
1313MODELS = ["base_model" , "batch_norm_model" , "batch_norm_model_sgd" , "batch_norm_model_rmsprop" ]
1414USE_EXISTING_MODELS = True
1515
16+ SAVE_MODELS = True
17+ SAVE_MODELS_AS_H5 = True
18+ SAVE_MODELS_AS_KERAS = True
19+ SAVE_MODELS_AS_SavedModel = True
20+
1621def create_predicition_matrix (model_handler : TensorModel , visualiser : Visualiser , model , x_test , y_test , str_model ):
1722 conf_matrix = model_handler .compute_confusion_matrix (model , x_test , y_test )
1823 visualiser .plot_confusion_matrix (conf_matrix , model_handler .get_class_names (), str_model )
@@ -22,8 +27,6 @@ def create_predicition_matrix(model_handler: TensorModel, visualiser: Visualiser
2227 visualiser .plot_diagonal_confusion_matrix (diagonal_matrix , model_handler .get_class_names (), str_model )
2328
2429def train_model (model_name :str , model_handler : TensorModel , visualiser : Visualiser , logger : Logger , x_train , y_train , x_test , y_test , batch_size ) -> tuple :
25- logger .info (f"Eager enabled: { tf .executing_eagerly ()} " )
26-
2730 # Check if model exists
2831 if USE_EXISTING_MODELS and os .path .exists (f"models/{ model_name } .h5" ):
2932 model = model = tf .keras .models .load_model (f"models/{ model_name } .h5" )
@@ -38,7 +41,16 @@ def train_model(model_name:str, model_handler: TensorModel, visualiser: Visualis
3841
3942 history = model .fit (x_train , y_train , epochs = NUM_EPOCHS , batch_size = batch_size , validation_data = (x_test , y_test ))
4043 test_loss , test_acc = model .evaluate (x_test , y_test )
41- model .save (f"models/{ model_name } .h5" )
44+
45+ if SAVE_MODELS :
46+ if SAVE_MODELS_AS_H5 :
47+ model .save (f"models/{ model_name } .h5" )
48+
49+ if SAVE_MODELS_AS_KERAS :
50+ model .save (f"models/{ model_name } .keras" )
51+
52+ if SAVE_MODELS_AS_SavedModel :
53+ model .export (f"models/{ model_name } _saved_model" )
4254
4355 logger .info (f"Model accuracy: { test_acc * 100 :.2f} %" )
4456 visualiser .plot_training_history (history , model_name )
@@ -76,14 +88,14 @@ def profile_models(model_acc_results:dict, visualiser: Visualiser, logger: Logge
7688 logger .info (f"Accuracy:{ accuracy * 100 :.2f} %\n " )
7789
7890if __name__ == "__main__" :
79- # Initialise variables
91+ # # Initialise variables
8092 model_acc_results = {}
8193
82- # Create a logger
94+ # # Create a logger
8395 logger = Logger (__name__ )
8496 logger .set_level (logging .INFO )
8597
86- # Initialise visualiser
98+ # # Initialise visualiser
8799 visualiser = Visualiser ()
88100
89101 # Initalise model handler
0 commit comments