
[Possible Bug] Running multiple initialization seeds #107

@Abzinger


Hi,

I wanted to train SAEs on a network with multiple initialization seeds, like this:

cfg = TrainConfig(
    sae_config, optimizer=optimizer, init_seeds=[0, 42],
    batch_size=b_size, layers=['6'],
    run_name=run_name, save_dir=save_dir,
)
trainer = Trainer(cfg, tokenized, gpt)
trainer.fit()  # raises KeyError: 'h.6'

However, I get KeyError: 'h.6' raised at line 375 of sparsify/trainer.py (raw = self.saes[name]). Full traceback:

  File "script_training_sae.py", line 97, in <module>
    trainer.fit()
  File "sparsify/sparsify/trainer.py", line 481, in fit
    self.model(x)
  File "lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1062, in forward
    transformer_outputs = self.transformer(
                          ^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 922, in forward
    outputs = block(
              ^^^^^^
  File "lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "lib/python3.11/site-packages/torch/nn/modules/module.py", line 1845, in _call_impl
    return inner()
           ^^^^^^^
  File "/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1806, in inner
    hook_result = hook(self, args, result)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "sparsify/sparsify/trainer.py", line 375, in hook
    raw = self.saes[name]
          ~~~~~~~~~^^^^^^
KeyError: 'h.6'

I get the same error when specifying hookpoints instead of layers. (Training on layer 6 with a single initialization seed works fine.)

Possible Bug

I think the problem might be the following. On line 77 of trainer.py:
name = f"{hook}/seed{seed}" if len(cfg.init_seeds) > 1 else hook
self.saes[name] = SparseCoder(
    input_widths[hook], cfg.sae, device, dtype=torch.float32
)
So with more than one seed, the self.saes dictionary is keyed by f"{hook}/seed{seed}" (e.g. "h.6/seed0" and "h.6/seed42"), but the hook in fit() (line 375) does raw = self.saes[name] with name equal to just f"{hook}" (e.g. "h.6"). The problem could be that the hook gets its name from the mapping built on line 328, which uses the bare hookpoint names rather than the seed-suffixed keys:

name_to_module = {
    name: self.model.base_model.get_submodule(name)
    for name in self.cfg.hookpoints
}
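The mismatch can be reproduced without sparsify at all. The sketch below uses plain dictionaries in place of the trainer: build_saes only mimics the key scheme quoted from line 77 (assuming one SAE per hookpoint per seed, with a string standing in for SparseCoder), and saes_for_hookpoint is a hypothetical helper, not part of sparsify, showing one way a hook could find every seed's SAE for a given hookpoint.

```python
def build_saes(hookpoints, init_seeds):
    """Mimic the key scheme on line 77: add a seed suffix only when there are several seeds."""
    saes = {}
    for hook in hookpoints:
        for seed in init_seeds:
            name = f"{hook}/seed{seed}" if len(init_seeds) > 1 else hook
            saes[name] = f"SAE({hook}, seed={seed})"  # placeholder for a SparseCoder
    return saes


def saes_for_hookpoint(saes, hook):
    """Hypothetical helper: return every SAE trained on `hook`, whatever its seed."""
    return {
        name: sae
        for name, sae in saes.items()
        if name == hook or name.startswith(f"{hook}/seed")
    }


single = build_saes(["h.6"], [0])
multi = build_saes(["h.6"], [0, 42])

print(sorted(single))  # ['h.6']
print(sorted(multi))   # ['h.6/seed0', 'h.6/seed42']

# Indexing with the bare hookpoint name, as the hook does on line 375, fails:
try:
    multi["h.6"]
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'h.6'

# Matching on the hookpoint prefix finds the SAEs for both seeds:
print(sorted(saes_for_hookpoint(multi, "h.6")))  # ['h.6/seed0', 'h.6/seed42']
```

If this is the cause, a fix would presumably have the hook at line 375 loop over the per-seed keys for its hookpoint instead of indexing self.saes with the bare hookpoint name.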

I hope the error is reproducible.

Metadata

Assignees: No one assigned

Labels: bug (Something isn't working), good first issue (Good for newcomers)

Type: No type

Projects: No projects

Milestone: No milestone

Relationships: None yet

Development: No branches or pull requests
