diff --git a/deepcase/context_builder/context_builder.py b/deepcase/context_builder/context_builder.py index 1793315..5429a4f 100644 --- a/deepcase/context_builder/context_builder.py +++ b/deepcase/context_builder/context_builder.py @@ -290,8 +290,8 @@ def fit(self, X, y, epochs=10, batch_size=128, learning_rate=0.01, optimizer.step() # Update description - total_loss += loss.item() / X_.shape[1] - total_items += X_.shape[0] + total_loss += loss.item() + total_items += X_.shape[0] * y_.shape[1] if verbose: data.set_description( diff --git a/deepcase/interpreter/interpreter.py b/deepcase/interpreter/interpreter.py index a7098a7..7f0d46b 100644 --- a/deepcase/interpreter/interpreter.py +++ b/deepcase/interpreter/interpreter.py @@ -333,9 +333,9 @@ def fit_predict(self, ).predict( X = X, y = y, - iterations = 100, - batch_size = 1024, - verbose = False, + iterations = iterations, + batch_size = batch_size, + verbose = verbose, ) ########################################################################