Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions clrs/_src/baselines.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def __init__(
hidden_dim: int = 32,
encode_hints: bool = False,
decode_hints: bool = True,
compute_hint_loss: bool = True,
encoder_init: str = 'default',
use_lstm: bool = False,
learning_rate: float = 0.005,
Expand Down Expand Up @@ -175,6 +176,7 @@ def __init__(
message-passing vectors.
encode_hints: Whether to provide hints as model inputs.
decode_hints: Whether to provide hints as model outputs.
compute_hint_loss: Whether to add hint loss to the output loss.
encoder_init: The initialiser type to use for the encoders.
use_lstm: Whether to insert an LSTM after message passing.
learning_rate: Learning rate for training.
Expand Down Expand Up @@ -202,15 +204,19 @@ def __init__(

Raises:
ValueError: if `encode_hints=True` and `decode_hints=False`.
or if `compute_hint_loss=True` and `decode_hints=False`.
"""
super(BaselineModel, self).__init__(spec=spec)

if encode_hints and not decode_hints:
raise ValueError('`encode_hints=True`, `decode_hints=False` is invalid.')
if compute_hint_loss and not decode_hints:
raise ValueError('`compute_hint_loss=True`, `decode_hints=False` is invalid.')

assert hint_repred_mode in ['soft', 'hard', 'hard_on_eval']

self.decode_hints = decode_hints
self.compute_hint_loss = compute_hint_loss
self.checkpoint_path = checkpoint_path
self.name = name
self._freeze_processor = freeze_processor
Expand Down Expand Up @@ -414,7 +420,7 @@ def _loss(self, params, rng_key, feedback, algorithm_index):
)

# Optionally accumulate hint losses.
if self.decode_hints:
if self.compute_hint_loss:
for truth in feedback.features.hints:
total_loss += losses.hint_loss(
truth=truth,
Expand Down Expand Up @@ -455,7 +461,7 @@ def verbose_loss(self, feedback: _Feedback, extra_info) -> Dict[str, _Array]:
losses_ = {}

# Optionally accumulate hint losses.
if self.decode_hints:
if self.compute_hint_loss:
for truth in feedback.features.hints:
losses_.update(
losses.hint_loss(
Expand Down Expand Up @@ -591,7 +597,7 @@ def _loss(self, params, rng_key, feedback, mp_state, algorithm_index):
)

# Optionally accumulate hint losses.
if self.decode_hints:
if self.compute_hint_loss:
for truth in feedback.features.hints:
loss = losses.hint_loss_chunked(
truth=truth,
Expand Down
16 changes: 13 additions & 3 deletions clrs/examples/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
'during training instead of predicted hints. Only '
'pertinent in encoded_decoded modes.')
flags.DEFINE_enum('hint_mode', 'encoded_decoded',
['encoded_decoded', 'decoded_only', 'none'],
['encoded_decoded', 'decoded_only', 'none_encoded_decoded', 'none'],
'How should hints be used? Note, each mode defines a '
'separate task, with various difficulties. `encoded_decoded` '
'requires the model to explicitly materialise hint sequences '
Expand All @@ -85,7 +85,9 @@
'note that we currently do not make any efforts to '
'counterbalance the various hint losses. Hence, for certain '
'tasks, the best performance will now be achievable with no '
'hint usage at all (`none`).')
'hint usage at all (`none`). `none_encoded_decoded` produces '
'similar computation graphs to `encoded_decoded` but without '
'hint usage.')
flags.DEFINE_enum('hint_repred_mode', 'soft', ['soft', 'hard', 'hard_on_eval'],
'How to process predicted hints when fed back as inputs.'
'In soft mode, we use softmaxes for categoricals, pointers '
Expand Down Expand Up @@ -365,14 +367,21 @@ def main(unused_argv):
if FLAGS.hint_mode == 'encoded_decoded':
encode_hints = True
decode_hints = True
compute_hint_loss = True
elif FLAGS.hint_mode == 'decoded_only':
encode_hints = False
decode_hints = True
compute_hint_loss = True
elif FLAGS.hint_mode == 'none_encoded_decoded':
encode_hints = True
decode_hints = True
compute_hint_loss = False
elif FLAGS.hint_mode == 'none':
encode_hints = False
decode_hints = False
compute_hint_loss = False
else:
raise ValueError('Hint mode not in {encoded_decoded, decoded_only, none}.')
raise ValueError('Hint mode not in {encoded_decoded, decoded_only, none_encoded_decoded, none}.')

train_lengths = [int(x) for x in FLAGS.train_lengths]

Expand All @@ -396,6 +405,7 @@ def main(unused_argv):
hidden_dim=FLAGS.hidden_size,
encode_hints=encode_hints,
decode_hints=decode_hints,
compute_hint_loss=compute_hint_loss,
encoder_init=FLAGS.encoder_init,
use_lstm=FLAGS.use_lstm,
learning_rate=FLAGS.learning_rate,
Expand Down