Variance-based attribution of LLM outputs to layer-local components via matched-pair interchange interventions. This repository is the reference implementation for the factual recall setting (GPT-2 + CounterFact).
Given a prompt like "The Eiffel Tower is located in", which transformer layer and submodule (Attn or MLP) carries the most information about the correct continuation? IGSD answers this with two numbers per layer-local group: a total Sobol index `ST` and a main index `S1`. Both are computed by swapping single-layer activations between matched donor / source prompts and measuring the change in `logit(correct) - logit(foil)`.
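Here is a toy NumPy sketch of that swap. It is not the repository's implementation. A small residual stack stands in for a transformer, and each "group" writes one vector into the residual stream. For group k, the donor's cached output replaces the source's during a re-run. The model, `run_toy`, `logit_diff`, and the fixed correct/foil token ids are all hypothetical stand-ins.

```python
import numpy as np

def run_toy(x, weights, patch=None):
    """Run a toy residual stack; patch=(k, vec) overwrites group k's output with vec."""
    resid = x.copy()
    cache = []
    for k, W in enumerate(weights):
        out = np.tanh(W @ resid)              # group k's write into the residual stream
        if patch is not None and patch[0] == k:
            out = patch[1]                    # interchange: use the donor's cached activation
        cache.append(out)
        resid = resid + out
    return resid, cache

def logit_diff(resid, W_U, correct, foil):
    """Y = logit(correct) - logit(foil) from a toy unembedding."""
    logits = W_U @ resid
    return logits[correct] - logits[foil]

rng = np.random.default_rng(0)
d, vocab, n_groups = 8, 5, 4
weights = [rng.normal(scale=0.5, size=(d, d)) for _ in range(n_groups)]
W_U = rng.normal(size=(vocab, d))
x_A, x_B = rng.normal(size=d), rng.normal(size=d)   # stand-ins for source A and donor B
correct, foil = 2, 3                                # hypothetical token ids

resid_A, _ = run_toy(x_A, weights)
_, cache_B = run_toy(x_B, weights)                  # cache the donor's per-group outputs
Y_A = logit_diff(resid_A, W_U, correct, foil)
for k in range(n_groups):
    resid_patched, _ = run_toy(x_A, weights, patch=(k, cache_B[k]))
    print(f"group {k}: change in Y = {logit_diff(resid_patched, W_U, correct, foil) - Y_A:+.3f}")
```

If you patch a group with the source's own cached activation, the run reproduces `Y_A` exactly. That identity is what the self-interchange check relies on.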
```bash
git clone https://github.com/gyf9712/igsd
cd igsd
pip install -e .
```

Tested with Python ≥ 3.9, PyTorch ≥ 2.0, and transformer_lens ≥ 2.0.
The built-in factual prompt set is enough to run an end-to-end demo on GPT-2 small in about a minute on CPU:
```bash
python examples/toy_run.py
```

This writes `results_toy/igsd_factual_gpt2-small.json` and a matching `.png` plot of `ST` / `S1` per layer-local group.
Download the CounterFact JSON to `data/counterfact.json` (see `data/README.md`).
```bash
igsd-factual \
  --model gpt2-small \
  --counterfact data/counterfact.json \
  --n_prompts 2000 --m_sobol 2000 \
  --batch 16 --seed 42 \
  --output_dir results
python scripts/plot_attribution.py results/igsd_factual_gpt2-small.json
```

In the subject/relation factorial analysis, each source/donor pair (A, B) is assigned to one of three disjoint buckets:
| Bucket | subject_A vs subject_B | relation_A vs relation_B |
|---|---|---|
| same_subj | same | different |
| same_rel | different | same |
| diff_both | different | different |
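Here is a minimal sketch of the bucket assignment. It assumes subjects and relations are compared by exact identifier match, for example the subject string and a relation id. `assign_bucket` and the example relation names are hypothetical. A pair that shares both subject and relation fits none of the three buckets, so the sketch returns `None` for it.

```python
from typing import Optional

def assign_bucket(subject_a: str, relation_a: str,
                  subject_b: str, relation_b: str) -> Optional[str]:
    """Map a (source A, donor B) pair to a factorial bucket, or None if it fits none."""
    same_subj = subject_a == subject_b
    same_rel = relation_a == relation_b
    if same_subj and not same_rel:
        return "same_subj"
    if same_rel and not same_subj:
        return "same_rel"
    if not same_subj and not same_rel:
        return "diff_both"
    return None  # same subject AND same relation: outside all three buckets

print(assign_bucket("Eiffel Tower", "located_in", "Louvre", "located_in"))
```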
For a chosen group (e.g. `MLP_L0`), the script reports `ST_swap` (donor activation patched in) and `ST_zero` (activation zeroed out) for each bucket, with 95% bootstrap CIs:
```bash
igsd-factorial \
  --model gpt2-small \
  --counterfact data/counterfact.json \
  --target_group MLP_L0 \
  --target_per_bucket 200 \
  --output_dir results
```

Pre-registered interpretation:
- `ST(same_subj) < ST(same_rel)` → subject-dominant (the group carries subject-identity information).
- `ST(same_rel) < ST(same_subj)` → relation-dominant (the group carries relation information).
- Both `< ST(diff_both)` and roughly equal → mixed.
- Neither smaller than `ST(diff_both)` → confound-like (the group is generically fragile).
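The decision rules fit in a small function over point estimates. This is only a sketch. `interpret` is a hypothetical helper, and "roughly equal" is modeled as a difference of at most `rel_tol * ST(diff_both)`; the 10% default is an assumption. A real analysis would compare the bootstrap CIs rather than point estimates.

```python
def interpret(st_same_subj: float, st_same_rel: float, st_diff_both: float,
              rel_tol: float = 0.1) -> str:
    """Map per-bucket ST point estimates for one group to the pre-registered labels."""
    subj_below = st_same_subj < st_diff_both
    rel_below = st_same_rel < st_diff_both
    if not subj_below and not rel_below:
        return "confound-like"      # neither matched bucket is smaller than diff_both
    if subj_below and rel_below and abs(st_same_subj - st_same_rel) <= rel_tol * st_diff_both:
        return "mixed"              # both smaller than diff_both and roughly equal
    if st_same_subj < st_same_rel:
        return "subject-dominant"
    return "relation-dominant"

print(interpret(0.10, 0.50, 0.60))
```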
Comparing `ST_swap` with `ST_zero` distinguishes two cases: the transported donor content itself matters, or the group is generically fragile to any perturbation.
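This sketch computes `ST_swap` and `ST_zero` with percentile-bootstrap 95% CIs from per-pair outputs. It assumes both scores use the same `mean((Y_A - Y_patched)^2) / (2 * Var(Y_full))` normalization as the main Sobol estimator. It also holds `Var(Y_full)` fixed across bootstrap resamples. The function names and the synthetic data are hypothetical.

```python
import numpy as np

def st_estimate(Y_A, Y_patched, var_full):
    """Total-index estimate: mean squared change in Y divided by 2 * Var(Y_full)."""
    Y_A = np.asarray(Y_A, dtype=float)
    Y_patched = np.asarray(Y_patched, dtype=float)
    return np.mean((Y_A - Y_patched) ** 2) / (2.0 * var_full)

def bootstrap_ci(Y_A, Y_patched, var_full, n_boot=2000, alpha=0.05, seed=0):
    """Point estimate plus a percentile-bootstrap (1 - alpha) CI, resampling pairs."""
    Y_A = np.asarray(Y_A, dtype=float)
    Y_patched = np.asarray(Y_patched, dtype=float)
    rng = np.random.default_rng(seed)
    n = len(Y_A)
    stats = np.empty(n_boot)
    for b in range(n_boot):
        idx = rng.integers(0, n, size=n)      # resample pairs with replacement
        stats[b] = st_estimate(Y_A[idx], Y_patched[idx], var_full)
    lo, hi = np.quantile(stats, [alpha / 2, 1 - alpha / 2])
    return st_estimate(Y_A, Y_patched, var_full), lo, hi

# Synthetic example: zero-ablated outputs drift further from Y_A than swapped ones.
rng = np.random.default_rng(0)
Y_A = rng.normal(size=200)
Y_swap = Y_A + rng.normal(scale=0.3, size=200)
Y_zero = Y_A + rng.normal(scale=0.6, size=200)
st_swap = bootstrap_ci(Y_A, Y_swap, var_full=1.0)
st_zero = bootstrap_ci(Y_A, Y_zero, var_full=1.0)
print("ST_swap = %.3f [%.3f, %.3f]" % st_swap)
print("ST_zero = %.3f [%.3f, %.3f]" % st_zero)
```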
Features. For every prompt i we compute Y_i = logit(correct) - logit(foil)
at the last prompt token, plus per-(layer, head) direct logit attribution DLA
and per-layer MLP attribution MLP_DLA, all projected onto the
(correct - foil) direction through ln_final. The pairing feature vector
is the standardized concatenation [Y, sum_h DLA[:, l, h] + MLP_DLA[:, l]].
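Here is a NumPy sketch of the pairing feature vector. It assumes `DLA` has shape `(N, n_layers, n_heads)` and `MLP_DLA` has shape `(N, n_layers)`. It also assumes "standardized" means per-column z-scoring. `pairing_features` is a hypothetical helper, not the repository's `extract_features`.

```python
import numpy as np

def pairing_features(Y, DLA, MLP_DLA, eps=1e-8):
    """Standardized [Y, per-layer attention + MLP attribution] feature matrix.

    Y: (N,), DLA: (N, n_layers, n_heads), MLP_DLA: (N, n_layers).
    Returns an (N, 1 + n_layers) array with zero-mean, unit-variance columns.
    """
    per_layer = DLA.sum(axis=2) + MLP_DLA                       # sum_h DLA[:, l, h] + MLP_DLA[:, l]
    raw = np.concatenate([Y[:, None], per_layer], axis=1)       # (N, 1 + n_layers)
    return (raw - raw.mean(axis=0)) / (raw.std(axis=0) + eps)   # per-column z-score

# Random stand-ins with GPT-2 small shapes (12 layers, 12 heads).
rng = np.random.default_rng(0)
N = 100
X_feat = pairing_features(rng.normal(size=N), rng.normal(size=(N, 12, 12)), rng.normal(size=(N, 12)))
print(X_feat.shape)  # (100, 13)
```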
Donor matching. Source idx_A is sampled uniformly from 0..N-1. Donor
idx_B is drawn uniformly from idx_A's top-k nearest neighbours in
standardized feature space, with a self-collision rejection step.
`--donor_dist_thresh` discards pairs whose feature-space distance exceeds the threshold, which helps separate mechanistic effects from boundary-instability effects.
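The donor-matching procedure can be sketched as follows. `nn_donor_pairs_sketch` is a hypothetical stand-in for `igsd.nn_donor_pairs`. Euclidean distance, the dense N x N distance matrix, and the `max_tries` guard are assumptions of this sketch.

```python
import numpy as np

def nn_donor_pairs_sketch(features_std, m_pairs, k=10, seed=42,
                          dist_thresh=None, max_tries=None):
    """Sample (idx_A, idx_B): A uniform over prompts, B uniform over A's top-k neighbours."""
    rng = np.random.default_rng(seed)
    X = np.asarray(features_std, dtype=float)
    N = X.shape[0]
    k = min(k, N - 1)
    # Pairwise Euclidean distances via ||x||^2 + ||y||^2 - 2 x.y (avoids an N x N x D array).
    sq = (X ** 2).sum(axis=1)
    dist = np.sqrt(np.maximum(sq[:, None] + sq[None, :] - 2.0 * X @ X.T, 0.0))
    np.fill_diagonal(dist, np.inf)            # keep each prompt out of its own neighbour list
    knn = np.argsort(dist, axis=1)[:, :k]     # (N, k) indices of the k nearest neighbours
    idx_A, idx_B = [], []
    for _ in range(max_tries or 100 * m_pairs):
        if len(idx_A) == m_pairs:
            break
        a = int(rng.integers(N))              # source: uniform over 0..N-1
        b = int(knn[a, rng.integers(k)])      # donor: uniform over a's top-k neighbours
        if b == a:
            continue                          # self-collision rejection (defensive)
        if dist_thresh is not None and dist[a, b] > dist_thresh:
            continue                          # gate pairs by max feature-space distance
        idx_A.append(a)
        idx_B.append(b)
    if len(idx_A) < m_pairs:
        raise RuntimeError("not enough pairs; relax dist_thresh or raise max_tries")
    return np.array(idx_A), np.array(idx_B)

rng = np.random.default_rng(3)
F = rng.normal(size=(50, 5))
idx_A, idx_B = nn_donor_pairs_sketch(F, m_pairs=100, k=5, seed=0)
print(idx_A[:5], idx_B[:5])
```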
Sobol. Groups are layer-local: Attn_L{l} and MLP_L{l} for every
layer l, so K = 2 * n_layers. For each pair (A, B) we run two patched
forward passes per group:
```text
ST_k = mean( (Y_A - Y_patched_A_with_B_at_group_k)^2 ) / (2 * Var(Y_full))
S1_k = 1 - mean( (Y_B - Y_patched_B_with_A_at_group_k)^2 ) / (2 * Var(Y_full))
```
Var(Y_full) is taken over the IID prompt pool, not the matched A+B pool.
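The two estimators translate directly into vectorized NumPy. This sketch uses a hypothetical `sobol_indices` helper. It assumes the patched outputs are precomputed as `(M, K)` arrays, one column per group. It also assumes `Var(Y_full)` is NumPy's default population variance (`ddof=0`).

```python
import numpy as np

def sobol_indices(Y_A, Y_B, Y_A_patched, Y_B_patched, Y_full):
    """Per-group ST and S1 from matched-pair interchange outputs.

    Y_A, Y_B:     (M,)   unpatched metric for the source / donor prompt of each pair
    Y_A_patched:  (M, K) metric for A's run with B's activation at group k
    Y_B_patched:  (M, K) metric for B's run with A's activation at group k
    Y_full:       (N,)   metric over the IID prompt pool (sets the normalizer)
    """
    Y_A, Y_B = np.asarray(Y_A, dtype=float), np.asarray(Y_B, dtype=float)
    denom = 2.0 * np.var(np.asarray(Y_full, dtype=float))   # 2 * Var(Y_full), ddof=0
    ST = np.mean((Y_A[:, None] - np.asarray(Y_A_patched)) ** 2, axis=0) / denom
    S1 = 1.0 - np.mean((Y_B[:, None] - np.asarray(Y_B_patched)) ** 2, axis=0) / denom
    return ST, S1

rng = np.random.default_rng(0)
M, K = 64, 4
Y_A, Y_B = rng.normal(size=M), rng.normal(size=M)
ST, S1 = sobol_indices(Y_A, Y_B, Y_A[:, None] + 0.1 * rng.normal(size=(M, K)),
                       Y_B[:, None] + 0.1 * rng.normal(size=(M, K)), rng.normal(size=500))
print(ST.round(3), S1.round(3))
```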
Self-interchange. `--self_interchange` forces `idx_B = idx_A`. Every patch is then the identity, so all `ST_k` must be ~0. Any `ST_k` clearly above floating-point noise indicates a hook-plumbing bug. Run this check once per new model.
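You can check a self-interchange run automatically. In this sketch, the interleaved `Attn_L{l}` / `MLP_L{l}` ordering and the `1e-6` tolerance are assumptions. A run in reduced precision may need a looser tolerance. `layer_local_groups` and `flag_self_interchange` are hypothetical helpers, not `igsd.build_groups`.

```python
import numpy as np

def layer_local_groups(n_layers):
    """Group names Attn_L{l}, MLP_L{l} for every layer (assumed interleaved order), K = 2 * n_layers."""
    names = []
    for l in range(n_layers):
        names += [f"Attn_L{l}", f"MLP_L{l}"]
    return names

def flag_self_interchange(ST, names, tol=1e-6):
    """Return the groups whose ST is not ~0 under self-interchange (likely hook bugs)."""
    return [name for name, s in zip(names, np.asarray(ST, dtype=float)) if abs(s) > tol]

names = layer_local_groups(12)        # GPT-2 small: 12 layers -> 24 groups
ST_self = np.zeros(len(names))
ST_self[5] = 0.01                     # pretend one group's patch was mis-wired
print(flag_self_interchange(ST_self, names))
```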
```python
import igsd
from igsd.features import extract_features, tokenize_factual

# Load GPT-2 small and up to 500 CounterFact factual prompts.
model = igsd.load_model("gpt2-small")
prompts = igsd.load_factual_prompts(
    counterfact_path="data/counterfact.json",
    n_max=500,
    tokenizer=model.tokenizer,
)

# Tokenize, then extract per-prompt features (Y, DLA, MLP_DLA) used for donor matching.
tokens = tokenize_factual(prompts, model.tokenizer)
feats = extract_features(model, tokens, batch=8)

# Sample 500 (source, donor) pairs, each donor from its source's 10 nearest neighbours.
pairs = igsd.nn_donor_pairs(
    features_std=feats.feature_matrix,
    m_pairs=500, k=10, seed=42,
)

# Layer-local groups: Attn_L{l} and MLP_L{l} for every layer (K = 2 * n_layers).
groups = igsd.build_groups(model.cfg.n_layers)

# Run the patched forward passes and estimate the Sobol indices.
result = igsd.interchange_sobol(
    model,
    padded=tokens.padded,
    correct_ids=tokens.correct_ids,
    foil_ids=tokens.foil_ids,
    tgt_pos=tokens.tgt_pos,
    Y_full=feats.Y,
    idx_A=pairs.idx_A,
    idx_B=pairs.idx_B,
    groups=groups,
)
print(result.ST)  # [K]
print(result.invalid_groups)
```

To run the tests:

```bash
pip install -e ".[dev]"
pytest -m "not slow"  # fast unit tests
pytest                # also runs the gpt2-small self-interchange sanity check
```

If you use this code, please cite the IGSD paper (preprint reference TBD).
Licensed under the MIT License; see LICENSE.