Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions sparsify/sparse_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,11 @@ def forward(
auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False)

# Encourage the top ~50% of dead latents to predict the residual of the
# top k living latents
e_hat = self.decode(auxk_acts, auxk_indices)
# top k living latents. We call decoder_impl directly rather than
# self.decode because the residual target e already accounts for b_dec
# (sae_out includes it), so adding b_dec again here would double-count it.
assert self.W_dec is not None, "Decoder weight was not initialized."
e_hat = decoder_impl(auxk_indices, auxk_acts.to(self.dtype), self.W_dec.mT)
auxk_loss = (e_hat - e.detach()).pow(2).sum()
auxk_loss = scale * auxk_loss / total_variance
else:
Expand Down
61 changes: 61 additions & 0 deletions tests/test_auxk_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import torch

from sparsify import SparseCoder, SparseCoderConfig
from sparsify.utils import decoder_impl


def test_auxk_loss_does_not_double_count_b_dec():
"""The AuxK loss target ``e = y - sae_out`` already accounts for ``b_dec``
(since ``sae_out`` includes it), so the second decoder pass used to compute
``e_hat`` must *not* add ``b_dec`` again. See issue #132.

This runs on CPU using the eager decoder fallback, so it requires no GPU.
"""
torch.manual_seed(0)

d_in = 16
num_latents = 32
k = 4
batch = 8

sae = SparseCoder(
d_in,
SparseCoderConfig(num_latents=num_latents, k=k),
)

# Give b_dec a nonzero value; this is the normal case after init centers it
# on the data mean, and is exactly the situation the bug affects.
with torch.no_grad():
sae.b_dec.copy_(torch.randn(d_in))

x = torch.randn(batch, d_in)
dead_mask = torch.ones(num_latents, dtype=torch.bool)

out = sae(x, dead_mask=dead_mask)

# Recompute the AuxK loss by hand, decoding *without* re-adding b_dec.
top_acts, top_indices, pre_acts = sae.encode(x)
sae_out = sae.decode(top_acts, top_indices)
e = x - sae_out
total_variance = (x - x.mean(0)).pow(2).sum()

num_dead = int(dead_mask.sum())
k_aux = x.shape[-1] // 2
scale = min(num_dead / k_aux, 1.0)
k_aux = min(k_aux, num_dead)

auxk_latents = torch.where(dead_mask[None], pre_acts, -torch.inf)
auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False)

# Correct target: decode without adding b_dec a second time.
assert sae.W_dec is not None
e_hat = decoder_impl(auxk_indices, auxk_acts.to(sae.dtype), sae.W_dec.mT)
expected_auxk_loss = scale * (e_hat - e.detach()).pow(2).sum() / total_variance

torch.testing.assert_close(out.auxk_loss, expected_auxk_loss)

# Sanity check: the buggy formulation (re-adding b_dec) gives a *different*
# value when b_dec != 0, so this test would actually fail without the fix.
buggy_e_hat = e_hat + sae.b_dec
buggy_auxk_loss = scale * (buggy_e_hat - e.detach()).pow(2).sum() / total_variance
assert not torch.allclose(out.auxk_loss, buggy_auxk_loss)