66from include .TensorModel import TensorModel
77from include .ModelProfiler import ModelProfiler
88
9- NUM_EPOCHS = 5
9+ NUM_EPOCHS = 10
1010BATCH_FITTING = 128
1111BATCH_PROFILING = [32 , 64 , 128 ]
1212
1313MODELS = ["base_model" , "batch_norm_model" , "batch_norm_model_sgd" , "batch_norm_model_rmsprop" ]
1414USE_EXISTING_MODELS = True
1515
1616SAVE_MODELS = True
17- SAVE_MODELS_AS_H5 = True
17+ SAVE_MODELS_AS_H5 = False
1818SAVE_MODELS_AS_KERAS = True
19- SAVE_MODELS_AS_SavedModel = True
19+ SAVE_MODELS_AS_SavedModel = False
20+
21+ PROFILE_MODELS = True
2022
2123def create_predicition_matrix (model_handler : TensorModel , visualiser : Visualiser , model , x_test , y_test , str_model ):
2224 conf_matrix = model_handler .compute_confusion_matrix (model , x_test , y_test )
@@ -28,9 +30,9 @@ def create_predicition_matrix(model_handler: TensorModel, visualiser: Visualiser
2830
2931def train_model (model_name :str , model_handler : TensorModel , visualiser : Visualiser , logger : Logger , x_train , y_train , x_test , y_test , batch_size ) -> tuple :
3032 # Check if model exists
31- if USE_EXISTING_MODELS and os .path .exists (f"models/{ model_name } .h5 " ):
32- model = model = tf .keras .models .load_model (f"models/{ model_name } .h5 " )
33- model .compile (optimizer = "adam " , loss = "categorical_crossentropy" , metrics = ["accuracy" ])
33+ if USE_EXISTING_MODELS and os .path .exists (f"models/{ model_name } .keras " ):
34+ model = model = tf .keras .models .load_model (f"models/{ model_name } .keras " )
35+ model .compile (optimizer = "sgd " , loss = "categorical_crossentropy" , metrics = ["accuracy" ])
3436 model .build ()
3537 model .summary ()
3638 else :
@@ -39,8 +41,8 @@ def train_model(model_name:str, model_handler: TensorModel, visualiser: Visualis
3941 else :
4042 model = model_handler .create_cnn (batch_normalisation = True )
4143
42- history = model .fit (x_train , y_train , epochs = NUM_EPOCHS , batch_size = batch_size , validation_data = (x_test , y_test ))
43- test_loss , test_acc = model .evaluate (x_test , y_test )
44+ history = model .fit (x_train , y_train , steps_per_epoch = x_train . shape [ 0 ], epochs = NUM_EPOCHS , batch_size = batch_size , validation_data = (x_test , y_test ))
45+ test_loss , test_acc = model .evaluate (x_test , y_test , batch_size = batch_size )
4446
4547 if SAVE_MODELS :
4648 if SAVE_MODELS_AS_H5 :
@@ -52,7 +54,7 @@ def train_model(model_name:str, model_handler: TensorModel, visualiser: Visualis
5254 if SAVE_MODELS_AS_SavedModel :
5355 model .export (f"models/{ model_name } _saved_model" )
5456
55- logger .info (f"Model accuracy: { test_acc * 100 :.2f} %" )
57+ logger .info (f"{ model_name } Model accuracy: { test_acc * 100 :.2f} %" )
5658 visualiser .plot_training_history (history , model_name )
5759 return (model , model_name ), (test_acc )
5860
@@ -72,7 +74,7 @@ def profile_models(model_acc_results:dict, visualiser: Visualiser, logger: Logge
7274 # Load models
7375 for model_name , accuracy in model_acc_results .items ():
7476 # Load the model and data
75- model = model_handler .load_model (f"models/{ model_name } .h5 " )
77+ model = model_handler .load_model (f"models/{ model_name } .keras " )
7678 (_ , _ ), (x_test , _ ) = model_handler .load_data ()
7779 (batch_time , throughput_time ), (single_image_time ) = profiler .measure_average_inference_time (batch_size , model , x_test , show_single_image_inference = True )
7880
@@ -108,6 +110,8 @@ def profile_models(model_acc_results:dict, visualiser: Visualiser, logger: Logge
108110
109111 # Train and profile models
110112 model_acc_results = train_models (model_handler , visualiser , logger , x_train , y_train , x_test , y_test , model_acc_results )
111- profile_models (model_acc_results , visualiser , logger )
113+
114+ if PROFILE_MODELS :
115+ profile_models (model_acc_results , visualiser , logger )
112116
113117 logger .info ("Done!" )
0 commit comments