From 6106e0ea164faadade813f57819a48cf7774de7b Mon Sep 17 00:00:00 2001 From: FHof Date: Thu, 14 May 2020 18:23:44 +0200 Subject: [PATCH] Fix the self.model.predict crash I have applied the suggestion from https://github.com/tensorflow/tensorflow/issues/28287#issuecomment-495005162 I'm not familiar with tensorflow, so I don't know if I've fixed the problem correctly. Now the trainer is no longer stuck in the waiting state. --- sources/agent.py | 5 +++++ sources/trainer.py | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/sources/agent.py b/sources/agent.py index c2d6a98..0bdb734 100644 --- a/sources/agent.py +++ b/sources/agent.py @@ -22,6 +22,7 @@ import keras.backend.tensorflow_backend as backend from keras.optimizers import Adam from keras.models import load_model, Model +from keras.backend import set_session sys.stdin = stdin sys.stderr = stderr @@ -36,6 +37,8 @@ def __init__(self, model_path=False, id=None): # Set to show an output from Conv2D layer self.show_conv_cam = (id + 1) in settings.CONV_CAM_AGENTS + self.sess = tf.Session() + # Main model (agent does not use target model) self.model = self.create_model(prediction=True) @@ -48,6 +51,8 @@ def __init__(self, model_path=False, id=None): # Load or create a new model (loading a model is being used only when playing or by trainer class that inherits from agent) def create_model(self, prediction=False): + set_session(self.sess) + # If there is a patht to the model set, load model if self.model_path: model = load_model(self.model_path) diff --git a/sources/trainer.py b/sources/trainer.py index 366b4f9..872d8f5 100644 --- a/sources/trainer.py +++ b/sources/trainer.py @@ -21,6 +21,7 @@ import tensorflow as tf tf.logging.set_verbosity(tf.logging.ERROR) import keras.backend.tensorflow_backend as backend +from keras.backend import set_session sys.stdin = stdin sys.stderr = stderr @@ -29,6 +30,8 @@ class ARTDQNTrainer(ARTDQNAgent): def __init__(self, model_path): + self.sess = tf.Session() + # If model path is beiong passed in - use it instead of creating a new one self.model_path = model_path self.model = self.create_model() @@ -107,6 +110,7 @@ def train(self): current_states.append((np.array([[transition[0][1]] for transition in minibatch]) - 50) / 50) # We need to use previously saved graph here as this is going to be called from separate thread with self.graph.as_default(): + set_session(self.sess) current_qs_list = self.model.predict(current_states, settings.PREDICTION_BATCH_SIZE) # Get future states from minibatch, then query NN model for Q values