gyf9712/igsd

IGSD — Interchange-intervention Sobol Decomposition

Variance-based attribution of LLM outputs to layer-local components via matched-pair interchange interventions. This repository is the reference implementation for the factual recall setting (GPT-2 + CounterFact).

Given a prompt like "The Eiffel Tower is located in", which transformer layer and submodule (Attn or MLP) carries the most information about the correct continuation? IGSD answers this with two numbers per layer-local group: a total Sobol index ST and a first-order (main-effect) index S1. Both are computed by swapping single-layer activations between matched donor/source prompts and measuring the change in logit(correct) - logit(foil).

Install

git clone https://github.com/gyf9712/igsd
cd igsd
pip install -e .

Tested with Python ≥ 3.9, PyTorch ≥ 2.0, transformer_lens ≥ 2.0.

Quickstart (no data download)

The built-in factual prompt set is enough to run an end-to-end demo on GPT-2 small in about a minute on CPU:

python examples/toy_run.py

This writes results_toy/igsd_factual_gpt2-small.json and a matching .png of ST / S1 per layer-local group.

Full reproduction (CounterFact)

Download the CounterFact JSON to data/counterfact.json (see data/README.md).

Pick-freeze Sobol over the full factual set

igsd-factual \
    --model gpt2-small \
    --counterfact data/counterfact.json \
    --n_prompts 2000 --m_sobol 2000 \
    --batch 16 --seed 42 \
    --output_dir results
python scripts/plot_attribution.py results/igsd_factual_gpt2-small.json

Subject / relation factorial buckets

Each pair (A, B) is assigned to one of three disjoint buckets:

Bucket      subject_A vs subject_B   relation_A vs relation_B
same_subj   same                     different
same_rel    different                same
diff_both   different                different
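The bucket rule can be sketched as a small function. This sketch assumes each CounterFact prompt is annotated with a subject string and a relation identifier. `assign_bucket` is a hypothetical helper for illustration and is not part of the igsd API. Pairs that share both subject and relation fall outside all three buckets, so the function returns None for them.

```python
from typing import Optional


def assign_bucket(subject_a: str, relation_a: str,
                  subject_b: str, relation_b: str) -> Optional[str]:
    """Hypothetical helper: map a (source A, donor B) pair to its bucket.

    Returns None when A and B share both subject and relation, since such
    pairs belong to none of the three disjoint buckets.
    """
    same_subject = subject_a == subject_b
    same_relation = relation_a == relation_b
    if same_subject and not same_relation:
        return "same_subj"
    if same_relation and not same_subject:
        return "same_rel"
    if not same_subject and not same_relation:
        return "diff_both"
    return None  # same subject and same relation: excluded


print(assign_bucket("Eiffel Tower", "location", "Colosseum", "location"))  # same_rel
```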

For a chosen group (e.g. MLP_L0), the script reports, for each bucket, ST_swap (donor activation patched in) and ST_zero (activation zeroed out), each with a 95% bootstrap CI:

igsd-factorial \
    --model gpt2-small \
    --counterfact data/counterfact.json \
    --target_group MLP_L0 \
    --target_per_bucket 200 \
    --output_dir results
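One common way to obtain such intervals is a percentile bootstrap over pairs. The sketch below makes two assumptions. First, ST is estimated as mean((Y_A - Y_patched)^2) / (2 * Var(Y_full)), as in the Sobol section. Second, Var(Y_full) is held fixed while pairs are resampled. `bootstrap_st_ci` and its defaults are hypothetical, and the repository's own resampling scheme may differ.

```python
import numpy as np


def bootstrap_st_ci(y_orig, y_patched, var_full, n_boot=2000, alpha=0.05, seed=0):
    """Point estimate and percentile-bootstrap CI for
    ST = mean((y_orig - y_patched)^2) / (2 * var_full).

    Pairs are resampled with replacement; var_full is held fixed.
    """
    rng = np.random.default_rng(seed)
    sq = (np.asarray(y_orig, dtype=float) - np.asarray(y_patched, dtype=float)) ** 2
    n = sq.shape[0]
    point = sq.mean() / (2.0 * var_full)
    resampled = rng.integers(0, n, size=(n_boot, n))   # bootstrap pair indices
    boot = sq[resampled].mean(axis=1) / (2.0 * var_full)
    lo, hi = np.quantile(boot, [alpha / 2, 1 - alpha / 2])
    return point, lo, hi


# Usage with stand-in data for 200 pairs in one bucket.
rng = np.random.default_rng(1)
y_a = rng.normal(size=200)
y_swap = y_a + rng.normal(scale=0.5, size=200)
print(bootstrap_st_ci(y_a, y_swap, var_full=1.0))
```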

Pre-registered interpretation:

  • ST(same_subj) < ST(same_rel) → subject-dominant (group carries subject-identity information).
  • ST(same_rel) < ST(same_subj) → relation-dominant (group carries relation information).
  • both < ST(diff_both) and roughly equal to each other → mixed.
  • neither smaller than ST(diff_both) → confound-like (group is generically fragile).
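The four rules can be written as a decision function on per-bucket ST estimates. This is a hypothetical sketch: `interpret_buckets` and its `tol` threshold for "roughly equal" are assumptions. It also compares point estimates, whereas the script reports bootstrap CIs, which a careful reading should use instead.

```python
def interpret_buckets(st_same_subj: float, st_same_rel: float,
                      st_diff_both: float, tol: float = 0.05) -> str:
    """Apply the pre-registered reading to per-bucket ST point estimates.

    `tol` is a hypothetical absolute threshold for "roughly equal".
    """
    subj_lower = st_same_subj < st_diff_both
    rel_lower = st_same_rel < st_diff_both
    if not subj_lower and not rel_lower:
        return "confound-like"        # neither bucket below diff_both
    if subj_lower and rel_lower and abs(st_same_subj - st_same_rel) <= tol:
        return "mixed"                # both below diff_both, roughly equal
    if st_same_subj < st_same_rel:
        return "subject-dominant"     # sharing the subject shrinks ST most
    return "relation-dominant"        # sharing the relation shrinks ST most


print(interpret_buckets(0.10, 0.50, 0.60))  # subject-dominant
```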

Comparing ST_swap to ST_zero separates "transported content matters" from "generic fragility".
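A toy model illustrates the distinction. In the sketch below, the hypothetical `bias` group emits the same constant for every prompt. Swapping it between prompts therefore changes nothing (ST_swap = 0), while zeroing it shifts every output (large ST_zero). The `content` group carries prompt-specific information, so swapping it does matter. ST_zero is assumed here to use the same normalization as ST_swap, with the zeroed pass in place of the donor pass; the repository's exact definition may differ.

```python
import numpy as np


def toy_forward(x, patch=None):
    """Hypothetical two-group toy: Y = sum(content) + sum(bias).

    `content` equals the prompt features x; `bias` is the same constant
    for every prompt. `patch` maps group name -> replacement output.
    """
    outs = {"content": x, "bias": np.full_like(x, 3.0)}
    if patch:
        outs.update(patch)
    return outs, outs["content"].sum(axis=1) + outs["bias"].sum(axis=1)


rng = np.random.default_rng(0)
_, y_pool = toy_forward(rng.normal(size=(5000, 4)))
var_full = y_pool.var()                       # Var over an IID pool

x_a = rng.normal(size=(5000, 4))              # sources
x_b = rng.normal(size=(5000, 4))              # donors
outs_a, y_a = toy_forward(x_a)
outs_b, _ = toy_forward(x_b)


def st(y_patched):
    return np.mean((y_a - y_patched) ** 2) / (2.0 * var_full)


results = {}
for group in ("content", "bias"):
    _, y_swap = toy_forward(x_a, {group: outs_b[group]})                 # donor patched in
    _, y_zero = toy_forward(x_a, {group: np.zeros_like(outs_a[group])})  # zeroed out
    results[group] = (st(y_swap), st(y_zero))
    print(f"{group}: ST_swap={results[group][0]:.2f}  ST_zero={results[group][1]:.2f}")
```

ST_zero is not bounded by 1 in this toy. Zeroing can move the output much further than any donor swap would.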

How it works (one paragraph each)

Features. For every prompt i we compute Y_i = logit(correct) - logit(foil) at the last prompt token, plus per-(layer, head) direct logit attribution DLA and per-layer MLP attribution MLP_DLA, all projected onto the (correct - foil) direction through ln_final. The pairing feature vector is the standardized concatenation [Y, sum_h DLA[:, l, h] + MLP_DLA[:, l]].
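The following is a minimal NumPy sketch of the feature construction. It assumes the last-token logits and the DLA / MLP_DLA tensors have already been computed, with shapes as given in the comments. It also assumes that "standardized" means a per-column z-score over the prompt pool. `logit_diff` and `pairing_features` are hypothetical names, not the igsd API.

```python
import numpy as np


def logit_diff(final_logits, correct_ids, foil_ids):
    """Y_i = logit(correct_i) - logit(foil_i).

    final_logits: [N, vocab], already taken at each prompt's last token.
    """
    rows = np.arange(final_logits.shape[0])
    return final_logits[rows, correct_ids] - final_logits[rows, foil_ids]


def pairing_features(y, dla, mlp_dla, eps=1e-8):
    """Standardized [Y, sum_h DLA[:, l, h] + MLP_DLA[:, l]] for every layer l.

    y: [N]; dla: [N, n_layers, n_heads]; mlp_dla: [N, n_layers].
    Returns [N, 1 + n_layers], each column z-scored over the N prompts.
    """
    per_layer = dla.sum(axis=2) + mlp_dla              # [N, n_layers]
    raw = np.concatenate([y[:, None], per_layer], axis=1)
    return (raw - raw.mean(axis=0)) / (raw.std(axis=0) + eps)


# Usage with random stand-in arrays (GPT-2 small has 12 layers x 12 heads).
rng = np.random.default_rng(0)
N, L, H, V = 64, 12, 12, 100
y = logit_diff(rng.normal(size=(N, V)),
               rng.integers(0, V, size=N), rng.integers(0, V, size=N))
feats = pairing_features(y, rng.normal(size=(N, L, H)), rng.normal(size=(N, L)))
print(feats.shape)  # (64, 13)
```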

Donor matching. Source idx_A is sampled uniformly from 0..N-1. Donor idx_B is drawn uniformly from idx_A's top-k nearest neighbours in standardized feature space, with a self-collision rejection step. --donor_dist_thresh drops pairs whose feature-space distance exceeds the threshold, which helps separate mechanistic effects from boundary-instability effects.
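The sampling procedure can be sketched as follows, assuming Euclidean distance in the standardized feature space. `sample_nn_donors` is a hypothetical stand-in for igsd.nn_donor_pairs, not its actual implementation. The optional `dist_thresh` mimics the max-distance gating that --donor_dist_thresh performs.

```python
import numpy as np


def sample_nn_donors(features_std, m_pairs, k=10, seed=42, dist_thresh=None):
    """Hypothetical sketch: draw (idx_A, idx_B) source/donor pairs.

    idx_A is uniform over 0..N-1; idx_B is uniform over idx_A's k nearest
    neighbours (Euclidean). Draws with idx_B == idx_A are rejected and
    redrawn. If dist_thresh is set, pairs farther apart are dropped.
    """
    n = features_std.shape[0]
    if not 2 <= k <= n:
        raise ValueError("need 2 <= k <= N so a non-self donor exists")
    rng = np.random.default_rng(seed)

    # Pairwise Euclidean distances via |x|^2 + |y|^2 - 2 x.y, clipped at 0.
    sq = (features_std ** 2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * features_std @ features_std.T
    dist = np.sqrt(np.clip(d2, 0.0, None))

    # Indices of each row's k nearest neighbours (usually including itself).
    knn = np.argsort(dist, axis=1, kind="stable")[:, :k]

    idx_a = rng.integers(0, n, size=m_pairs)
    idx_b = np.empty_like(idx_a)
    for i, a in enumerate(idx_a):
        b = a
        while b == a:  # self-collision rejection
            b = rng.choice(knn[a])
        idx_b[i] = b

    if dist_thresh is not None:
        keep = dist[idx_a, idx_b] <= dist_thresh
        idx_a, idx_b = idx_a[keep], idx_b[keep]
    return idx_a, idx_b
```

A row's own index is usually among its k nearest neighbours, at distance 0. The rejection loop is therefore what guarantees idx_B != idx_A, and it needs k >= 2 to terminate.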

Sobol. Groups are layer-local: Attn_L{l} and MLP_L{l} for every layer l, so K = 2 * n_layers. For each pair (A, B) we run two patched forward passes per group:

ST_k = mean( (Y_A - Y_patched_A_with_B_at_group_k)^2 ) / (2 * Var(Y_full))
S1_k = 1 - mean( (Y_B - Y_patched_B_with_A_at_group_k)^2 ) / (2 * Var(Y_full))

Var(Y_full) is taken over the IID prompt pool, not the matched A+B pool.
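The ST estimator can be checked on a toy model where the answer is known. The sketch below is hypothetical: three independent "groups" contribute additively with coefficients c_k, so the exact total index is c_k^2 / sum(c^2). It treats groups as independent inputs, which real layer activations are not, because patching one layer changes everything downstream. It therefore checks only the arithmetic of the estimator and the use of a separate IID pool for Var(Y_full).

```python
import numpy as np


def total_sobol(y_a, y_a_patched, var_full):
    """ST_k = mean((Y_A - Y_patched_A_with_B_at_group_k)^2) / (2 * Var(Y_full))."""
    return np.mean((y_a - y_a_patched) ** 2) / (2.0 * var_full)


coef = np.array([1.0, 2.0, 3.0])   # hypothetical per-group effect sizes


def f(x):
    """Toy additive model: Y = sum_k c_k * x_k."""
    return x @ coef


rng = np.random.default_rng(0)
n = 200_000
var_full = f(rng.normal(size=(n, 3))).var()   # Var over a separate IID pool
x_a = rng.normal(size=(n, 3))                 # source inputs
x_b = rng.normal(size=(n, 3))                 # donor inputs

st = []
for k in range(3):
    x_ab = x_a.copy()
    x_ab[:, k] = x_b[:, k]                    # A with group k taken from B
    st.append(total_sobol(f(x_a), f(x_ab), var_full))

exact = coef ** 2 / (coef ** 2).sum()
print("ST estimate:", np.round(st, 3), "exact:", np.round(exact, 3))
```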

Self-interchange. --self_interchange forces idx_B = idx_A; every patch is identity, so all ST_k must be ~0. Any non-zero ST flags a hook-plumbing bug. Run this once per new model.
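The logic of the check can be reproduced on a toy residual stack. Patching a group's output with the output the same input already produced must leave Y unchanged. The model, the `run` helper and its `patch` argument are hypothetical stand-ins for the real hook plumbing. With identity donors, every ST_k is zero up to floating-point noise, while shuffled donors give clearly non-zero values.

```python
import numpy as np

rng = np.random.default_rng(0)
D, N_GROUPS = 8, 4
weights = [rng.normal(scale=0.5, size=(D, D)) for _ in range(N_GROUPS)]
readout = rng.normal(size=D)


def run(x, patch=None):
    """Toy residual stack h <- h + tanh(h @ W_k); returns (Y, group outputs).

    patch=(k, out) overwrites group k's output with `out`, like a hook would.
    """
    h, outs = x.copy(), []
    for k, w in enumerate(weights):
        out = np.tanh(h @ w)
        if patch is not None and patch[0] == k:
            out = patch[1]
        outs.append(out)
        h = h + out
    return h @ readout, outs


x = rng.normal(size=(256, D))
y, _ = run(x)
var_full = y.var()


def st_per_group(idx_b):
    """ST_k for sources x and donors x[idx_b], one value per group."""
    _, donor_outs = run(x[idx_b])
    return [np.mean((y - run(x, (k, donor_outs[k]))[0]) ** 2) / (2 * var_full)
            for k in range(N_GROUPS)]


self_st = st_per_group(np.arange(len(x)))      # idx_B = idx_A: identity patches
shuffled_st = st_per_group(rng.permutation(len(x)))
print("self:", np.round(self_st, 6), "shuffled:", np.round(shuffled_st, 3))
```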

API

import igsd
from igsd.features import extract_features, tokenize_factual

# Load GPT-2 small and up to 500 CounterFact prompts.
model = igsd.load_model("gpt2-small")
prompts = igsd.load_factual_prompts(
    counterfact_path="data/counterfact.json",
    n_max=500,
    tokenizer=model.tokenizer,
)

# Tokenize, then compute Y and the standardized pairing features.
tokens = tokenize_factual(prompts, model.tokenizer)
feats = extract_features(model, tokens, batch=8)

# Matched (source, donor) pairs drawn from each source's k = 10 nearest neighbours.
pairs = igsd.nn_donor_pairs(
    features_std=feats.feature_matrix,
    m_pairs=500, k=10, seed=42,
)

# Layer-local groups (Attn_L{l}, MLP_L{l}), then the interchange Sobol run.
groups = igsd.build_groups(model.cfg.n_layers)
result = igsd.interchange_sobol(
    model,
    padded=tokens.padded,
    correct_ids=tokens.correct_ids,
    foil_ids=tokens.foil_ids,
    tgt_pos=tokens.tgt_pos,
    Y_full=feats.Y,
    idx_A=pairs.idx_A,
    idx_B=pairs.idx_B,
    groups=groups,
)
print(result.ST)         # [K]
print(result.invalid_groups)

Tests

pip install -e ".[dev]"
pytest -m "not slow"   # fast unit tests
pytest                 # also runs the gpt2-small self-interchange sanity check

Citation

If you use this code, please cite the IGSD paper (preprint reference TBD).

License

MIT. See LICENSE.
