Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions deepcase/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@
batch_size = args.batch,
learning_rate = 0.01,
teach_ratio = 0.5,
delta = args.delta,
verbose = not args.silent,
)

Expand Down
14 changes: 11 additions & 3 deletions deepcase/context_builder/context_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def forward(self, X, y=None, steps=1, teach_ratio=0.5):
########################################################################

def fit(self, X, y, epochs=10, batch_size=128, learning_rate=0.01,
optimizer=optim.SGD, teach_ratio=0.5, verbose=True):
optimizer=optim.SGD, teach_ratio=0.5, delta=0.1, verbose=True):
"""Fit the sequence predictor with labelled data

Parameters
Expand All @@ -222,6 +222,9 @@ def fit(self, X, y, epochs=10, batch_size=128, learning_rate=0.01,
teach_ratio : float, default=0.5
Ratio of sequences to train including labels.

delta : float, default=0.1
Label smoothing factor to apply during training.

verbose : boolean, default=True
If True, prints progress.

Expand All @@ -243,7 +246,7 @@ def fit(self, X, y, epochs=10, batch_size=128, learning_rate=0.01,
self.train()

# Set criterion and optimiser
criterion = LabelSmoothing(self.decoder_event.out.out_features, 0.1)
criterion = LabelSmoothing(self.decoder_event.out.out_features, delta)
optimizer = optimizer(
params = self.parameters(),
lr = learning_rate
Expand Down Expand Up @@ -356,7 +359,8 @@ def predict(self, X, y=None, steps=1):


def fit_predict(self, X, y, epochs=10, batch_size=128, learning_rate=0.01,
optimizer=optim.SGD, teach_ratio=0.5, verbose=True):
optimizer=optim.SGD, teach_ratio=0.5, delta=0.1,
verbose=True):
"""Fit the sequence predictor with labelled data

Parameters
Expand All @@ -382,6 +386,9 @@ def fit_predict(self, X, y, epochs=10, batch_size=128, learning_rate=0.01,
teach_ratio : float, default=0.5
Ratio of sequences to train including labels

delta : float, default=0.1
Label smoothing factor to apply during training

verbose : boolean, default=True
If True, prints progress

Expand All @@ -401,6 +408,7 @@ def fit_predict(self, X, y, epochs=10, batch_size=128, learning_rate=0.01,
learning_rate = learning_rate,
optimizer = optimizer,
teach_ratio = teach_ratio,
delta = delta,
verbose = verbose,
).predict(X)

Expand Down
10 changes: 10 additions & 0 deletions deepcase/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def fit(self,
learning_rate = 0.01,
optimizer = optim.SGD,
teach_ratio = 0.5,
delta = 0.1,

# Interpreter-specific parameters
iterations = 100,
Expand Down Expand Up @@ -132,6 +133,9 @@ def fit(self,
teach_ratio : float, default=0.5
Ratio of sequences to train including labels.

delta : float, default=0.1
Label smoothing factor to apply during ContextBuilder training.

iterations : int, default=100
Number of iterations for query.

Expand Down Expand Up @@ -169,6 +173,7 @@ def fit(self,
learning_rate = learning_rate,
optimizer = optimizer,
teach_ratio = teach_ratio,
delta = delta,
verbose = verbose,
)

Expand Down Expand Up @@ -249,6 +254,7 @@ def fit_predict(self,
learning_rate = 0.01,
optimizer = optim.SGD,
teach_ratio = 0.5,
delta = 0.1,

# Interpreter-specific parameters
iterations = 100,
Expand Down Expand Up @@ -291,6 +297,9 @@ def fit_predict(self,
teach_ratio : float, default=0.5
Ratio of sequences to train including labels.

delta : float, default=0.1
Label smoothing factor to apply during ContextBuilder training.

iterations : int, default=100
Number of iterations for query.

Expand Down Expand Up @@ -335,6 +344,7 @@ def fit_predict(self,
learning_rate = learning_rate,
optimizer = optimizer,
teach_ratio = teach_ratio,
delta = delta,
iterations = iterations,
query_batch_size = query_batch_size,
strategy = strategy,
Expand Down
1 change: 1 addition & 0 deletions docs/source/usage/code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ Once the ``context_builder`` is created, we train it using the :py:meth:`fit()`
epochs = 10, # Number of epochs to train with
batch_size = 128, # Number of samples in each training batch, in paper this was 128
learning_rate = 0.01, # Learning rate to train with, in paper this was 0.01
delta = 0.1, # Label smoothing factor
verbose = True, # If True, prints progress
)

Expand Down