Skip to content

Commit b2845c5

Browse files
authored
Merge pull request #9 from Arief-AK/optimisation
Optimisation: Added more data augmentation
2 parents 47876cf + 707c775 commit b2845c5

52 files changed

Lines changed: 33 additions & 10 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

README.md

Lines changed: 3 additions & 0 deletions

Train.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,18 @@
66
from include.TensorModel import TensorModel
77
from include.ModelProfiler import ModelProfiler
88

9-
NUM_EPOCHS = 25
9+
NUM_EPOCHS = 5
1010
BATCH_FITTING = 128
1111
BATCH_PROFILING = [32, 64, 128]
1212

1313
MODELS = ["base_model", "batch_norm_model", "batch_norm_model_sgd", "batch_norm_model_rmsprop"]
1414
USE_EXISTING_MODELS = True
1515

16+
SAVE_MODELS = True
17+
SAVE_MODELS_AS_H5 = True
18+
SAVE_MODELS_AS_KERAS = True
19+
SAVE_MODELS_AS_SavedModel = True
20+
1621
def create_predicition_matrix(model_handler: TensorModel, visualiser: Visualiser, model, x_test, y_test, str_model):
1722
conf_matrix = model_handler.compute_confusion_matrix(model, x_test, y_test)
1823
visualiser.plot_confusion_matrix(conf_matrix, model_handler.get_class_names(), str_model)
@@ -22,8 +27,6 @@ def create_predicition_matrix(model_handler: TensorModel, visualiser: Visualiser
2227
visualiser.plot_diagonal_confusion_matrix(diagonal_matrix, model_handler.get_class_names(), str_model)
2328

2429
def train_model(model_name:str, model_handler: TensorModel, visualiser: Visualiser, logger: Logger, x_train, y_train, x_test, y_test, batch_size) -> tuple:
25-
logger.info(f"Eager enabled: {tf.executing_eagerly()}")
26-
2730
# Check if model exists
2831
if USE_EXISTING_MODELS and os.path.exists(f"models/{model_name}.h5"):
2932
model = model = tf.keras.models.load_model(f"models/{model_name}.h5")
@@ -38,7 +41,16 @@ def train_model(model_name:str, model_handler: TensorModel, visualiser: Visualis
3841

3942
history = model.fit(x_train, y_train, epochs=NUM_EPOCHS, batch_size=batch_size, validation_data=(x_test, y_test))
4043
test_loss, test_acc = model.evaluate(x_test, y_test)
41-
model.save(f"models/{model_name}.h5")
44+
45+
if SAVE_MODELS:
46+
if SAVE_MODELS_AS_H5:
47+
model.save(f"models/{model_name}.h5")
48+
49+
if SAVE_MODELS_AS_KERAS:
50+
model.save(f"models/{model_name}.keras")
51+
52+
if SAVE_MODELS_AS_SavedModel:
53+
model.export(f"models/{model_name}_saved_model")
4254

4355
logger.info(f"Model accuracy: {test_acc * 100:.2f}%")
4456
visualiser.plot_training_history(history, model_name)
@@ -76,14 +88,14 @@ def profile_models(model_acc_results:dict, visualiser: Visualiser, logger: Logge
7688
logger.info(f"Accuracy:{accuracy * 100:.2f}%\n")
7789

7890
if __name__ == "__main__":
79-
# Initialise variables
91+
# # Initialise variables
8092
model_acc_results = {}
8193

82-
# Create a logger
94+
# # Create a logger
8395
logger = Logger(__name__)
8496
logger.set_level(logging.INFO)
8597

86-
# Initialise visualiser
98+
# # Initialise visualiser
8799
visualiser = Visualiser()
88100

89101
# Initalise model handler
244 Bytes
14 Bytes
219 Bytes
272 Bytes
-1.24 KB
-296 Bytes
-176 Bytes
72 Bytes

0 commit comments

Comments
 (0)