diff --git a/clrs/_src/baselines.py b/clrs/_src/baselines.py index 97e30d7a..485b1443 100644 --- a/clrs/_src/baselines.py +++ b/clrs/_src/baselines.py @@ -144,6 +144,7 @@ def __init__( hidden_dim: int = 32, encode_hints: bool = False, decode_hints: bool = True, + compute_hint_loss: bool = True, encoder_init: str = 'default', use_lstm: bool = False, learning_rate: float = 0.005, @@ -175,6 +176,7 @@ def __init__( message-passing vectors. encode_hints: Whether to provide hints as model inputs. decode_hints: Whether to provide hints as model outputs. + compute_hint_loss: Whether to add hint loss to the output loss. encoder_init: The initialiser type to use for the encoders. use_lstm: Whether to insert an LSTM after message passing. learning_rate: Learning rate for training. @@ -202,15 +204,19 @@ def __init__( Raises: ValueError: if `encode_hints=True` and `decode_hints=False`. + if `compute_hint_loss=True` and `decode_hints=False`. """ super(BaselineModel, self).__init__(spec=spec) if encode_hints and not decode_hints: raise ValueError('`encode_hints=True`, `decode_hints=False` is invalid.') + if compute_hint_loss and not decode_hints: + raise ValueError('`compute_hint_loss=True`, `decode_hints=False` is invalid.') assert hint_repred_mode in ['soft', 'hard', 'hard_on_eval'] self.decode_hints = decode_hints + self.compute_hint_loss = compute_hint_loss self.checkpoint_path = checkpoint_path self.name = name self._freeze_processor = freeze_processor @@ -414,7 +420,7 @@ def _loss(self, params, rng_key, feedback, algorithm_index): ) # Optionally accumulate hint losses. - if self.decode_hints: + if self.compute_hint_loss: for truth in feedback.features.hints: total_loss += losses.hint_loss( truth=truth, @@ -455,7 +461,7 @@ def verbose_loss(self, feedback: _Feedback, extra_info) -> Dict[str, _Array]: losses_ = {} # Optionally accumulate hint losses. 
- if self.decode_hints: + if self.compute_hint_loss: for truth in feedback.features.hints: losses_.update( losses.hint_loss( truth=truth, @@ -591,7 +597,7 @@ def _loss(self, params, rng_key, feedback, mp_state, algorithm_index): ) # Optionally accumulate hint losses. - if self.decode_hints: + if self.compute_hint_loss: for truth in feedback.features.hints: loss = losses.hint_loss_chunked( truth=truth, diff --git a/clrs/examples/run.py b/clrs/examples/run.py index 2173e8d3..59c100bd 100644 --- a/clrs/examples/run.py +++ b/clrs/examples/run.py @@ -73,7 +73,7 @@ 'during training instead of predicted hints. Only ' 'pertinent in encoded_decoded modes.') flags.DEFINE_enum('hint_mode', 'encoded_decoded', - ['encoded_decoded', 'decoded_only', 'none'], + ['encoded_decoded', 'decoded_only', 'none_encoded_decoded', 'none'], 'How should hints be used? Note, each mode defines a ' 'separate task, with various difficulties. `encoded_decoded` ' 'requires the model to explicitly materialise hint sequences ' @@ -85,7 +85,9 @@ 'note that we currently do not make any efforts to ' 'counterbalance the various hint losses. Hence, for certain ' 'tasks, the best performance will now be achievable with no ' - 'hint usage at all (`none`).') + 'hint usage at all (`none`). `none_encoded_decoded` produces ' 'similar computation graphs to `encoded_decoded` but without ' 'hint usage.') flags.DEFINE_enum('hint_repred_mode', 'soft', ['soft', 'hard', 'hard_on_eval'], 'How to process predicted hints when fed back as inputs.' 
'In soft mode, we use softmaxes for categoricals, pointers ' @@ -365,14 +367,21 @@ def main(unused_argv): if FLAGS.hint_mode == 'encoded_decoded': encode_hints = True decode_hints = True + compute_hint_loss = True elif FLAGS.hint_mode == 'decoded_only': encode_hints = False decode_hints = True + compute_hint_loss = True + elif FLAGS.hint_mode == 'none_encoded_decoded': + encode_hints = True + decode_hints = True + compute_hint_loss = False elif FLAGS.hint_mode == 'none': encode_hints = False decode_hints = False + compute_hint_loss = False else: - raise ValueError('Hint mode not in {encoded_decoded, decoded_only, none}.') + raise ValueError('Hint mode not in {encoded_decoded, decoded_only, none_encoded_decoded, none}.') train_lengths = [int(x) for x in FLAGS.train_lengths] @@ -396,6 +405,7 @@ def main(unused_argv): hidden_dim=FLAGS.hidden_size, encode_hints=encode_hints, decode_hints=decode_hints, + compute_hint_loss=compute_hint_loss, encoder_init=FLAGS.encoder_init, use_lstm=FLAGS.use_lstm, learning_rate=FLAGS.learning_rate,