Skip to content

Commit 7f7b660

Browse files
authored
Merge pull request #11 from Arief-AK/optimisation
Optimisation: Trained models to reach higher than 85% accuracy
2 parents b2845c5 + 24ddbe5 commit 7f7b660

54 files changed

Lines changed: 69 additions & 99 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/runner_requirements.txt

Lines changed: 0 additions & 66 deletions
This file was deleted.

RealtimeClassification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import tensorflow as tf
44
import time
55

6-
from tensorflow.keras.models import load_model
6+
from tensorflow.keras.models import load_model # type: ignore
77
from include.Logger import Logger
88

99
class LiveCameraClassifier:
@@ -56,6 +56,6 @@ def run(self):
5656
# Example Usage
5757
if __name__ == "__main__":
5858
class_names = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
59-
model = load_model("models/batch_norm_model.h5")
59+
model = load_model("models/batch_norm_model_rmsprop.keras")
6060
classifier = LiveCameraClassifier(model, class_names)
6161
classifier.run()

Train.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,19 @@
66
from include.TensorModel import TensorModel
77
from include.ModelProfiler import ModelProfiler
88

9-
NUM_EPOCHS = 5
9+
NUM_EPOCHS = 10
1010
BATCH_FITTING = 128
1111
BATCH_PROFILING = [32, 64, 128]
1212

1313
MODELS = ["base_model", "batch_norm_model", "batch_norm_model_sgd", "batch_norm_model_rmsprop"]
1414
USE_EXISTING_MODELS = True
1515

1616
SAVE_MODELS = True
17-
SAVE_MODELS_AS_H5 = True
17+
SAVE_MODELS_AS_H5 = False
1818
SAVE_MODELS_AS_KERAS = True
19-
SAVE_MODELS_AS_SavedModel = True
19+
SAVE_MODELS_AS_SavedModel = False
20+
21+
PROFILE_MODELS = True
2022

2123
def create_predicition_matrix(model_handler: TensorModel, visualiser: Visualiser, model, x_test, y_test, str_model):
2224
conf_matrix = model_handler.compute_confusion_matrix(model, x_test, y_test)
@@ -28,9 +30,9 @@ def create_predicition_matrix(model_handler: TensorModel, visualiser: Visualiser
2830

2931
def train_model(model_name:str, model_handler: TensorModel, visualiser: Visualiser, logger: Logger, x_train, y_train, x_test, y_test, batch_size) -> tuple:
3032
# Check if model exists
31-
if USE_EXISTING_MODELS and os.path.exists(f"models/{model_name}.h5"):
32-
model = model = tf.keras.models.load_model(f"models/{model_name}.h5")
33-
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
33+
if USE_EXISTING_MODELS and os.path.exists(f"models/{model_name}.keras"):
34+
model = model = tf.keras.models.load_model(f"models/{model_name}.keras")
35+
model.compile(optimizer="sgd", loss="categorical_crossentropy", metrics=["accuracy"])
3436
model.build()
3537
model.summary()
3638
else:
@@ -39,8 +41,8 @@ def train_model(model_name:str, model_handler: TensorModel, visualiser: Visualis
3941
else:
4042
model = model_handler.create_cnn(batch_normalisation=True)
4143

42-
history = model.fit(x_train, y_train, epochs=NUM_EPOCHS, batch_size=batch_size, validation_data=(x_test, y_test))
43-
test_loss, test_acc = model.evaluate(x_test, y_test)
44+
history = model.fit(x_train, y_train, steps_per_epoch=x_train.shape[0], epochs=NUM_EPOCHS, batch_size=batch_size, validation_data=(x_test, y_test))
45+
test_loss, test_acc = model.evaluate(x_test, y_test, batch_size=batch_size)
4446

4547
if SAVE_MODELS:
4648
if SAVE_MODELS_AS_H5:
@@ -52,7 +54,7 @@ def train_model(model_name:str, model_handler: TensorModel, visualiser: Visualis
5254
if SAVE_MODELS_AS_SavedModel:
5355
model.export(f"models/{model_name}_saved_model")
5456

55-
logger.info(f"Model accuracy: {test_acc * 100:.2f}%")
57+
logger.info(f"{model_name} Model accuracy: {test_acc * 100:.2f}%")
5658
visualiser.plot_training_history(history, model_name)
5759
return (model, model_name), (test_acc)
5860

@@ -72,7 +74,7 @@ def profile_models(model_acc_results:dict, visualiser: Visualiser, logger: Logge
7274
# Load models
7375
for model_name, accuracy in model_acc_results.items():
7476
# Load the model and data
75-
model = model_handler.load_model(f"models/{model_name}.h5")
77+
model = model_handler.load_model(f"models/{model_name}.keras")
7678
(_, _), (x_test, _) = model_handler.load_data()
7779
(batch_time, throughput_time), (single_image_time) = profiler.measure_average_inference_time(batch_size, model, x_test, show_single_image_inference=True)
7880

@@ -108,6 +110,8 @@ def profile_models(model_acc_results:dict, visualiser: Visualiser, logger: Logge
108110

109111
# Train and profile models
110112
model_acc_results = train_models(model_handler, visualiser, logger, x_train, y_train, x_test, y_test, model_acc_results)
111-
profile_models(model_acc_results, visualiser, logger)
113+
114+
if PROFILE_MODELS:
115+
profile_models(model_acc_results, visualiser, logger)
112116

113117
logger.info("Done!")
498 Bytes
388 Bytes
84 Bytes
-966 Bytes
4.3 KB
-364 Bytes
-616 Bytes

0 commit comments

Comments
 (0)