
Add experimental causal-robust Toto finetuning path #88

Open
DhyeyMavani2003 wants to merge 1 commit into DataDog:main from DhyeyMavani2003:dmavani/causal-robust-toto

Conversation

@DhyeyMavani2003

Summary

This PR introduces an experimental causal-robust fine-tuning path for Toto, inspired by the causal-transformer margin/barrier formulation discussed in Causal_transformers_theory.pdf.

The implementation is intentionally backward-compatible and off by default.

Motivation

Recent causal-transformer theory suggests stabilizing sequence models under interventions by enforcing a positive robustness margin and penalizing near-violations with a barrier term.

In Toto, we add a practical approximation using attention-time statistics so we can:

  • improve robustness under temporal distribution shifts,
  • preserve current behavior when disabled,
  • expose clean tuning knobs for experimentation.

What Changed

  1. Attention-level robustness penalty (time-wise only)
  • File: toto/model/attention.py
  • Added optional causal-robustness configuration and tracking:
    • configure_causal_robustness(enabled, alpha, eps, max_penalty)
    • latest_causal_robustness_penalty
  • Added differentiable penalty computation from attention statistics:
    • Compute attention scores and weights over time
    • Estimate attention-weighted variance trace on values
    • Define margin: margin = 1 - alpha * variance_trace
    • Apply log barrier: penalty = -log(clamp(margin, min=eps))
    • Cap penalty by max_penalty for stability
  • Penalty is computed only when:
    • causal robustness is enabled,
    • model is in training mode,
    • attention axis is time.
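The penalty computation above can be sketched as a standalone function. This is a minimal illustration, assuming PyTorch; the function name, tensor shapes, and the einsum-based variance estimate are my assumptions, not the PR's actual implementation in toto/model/attention.py:

```python
import torch

def causal_robustness_penalty(attn_weights, values,
                              alpha=1e-4, eps=1e-6, max_penalty=20.0):
    # attn_weights: (batch, heads, T_query, T_key), rows sum to 1 over time
    # values:       (batch, heads, T_key, d)
    # Attention-weighted mean of values per query position
    mean = torch.einsum("bhqk,bhkd->bhqd", attn_weights, values)
    # Attention-weighted second moment, same shape
    second = torch.einsum("bhqk,bhkd->bhqd", attn_weights, values ** 2)
    # Variance trace: sum per-dimension variances, then average everywhere else
    variance_trace = (second - mean ** 2).sum(dim=-1).mean()
    # Margin is positive when the attention-weighted variance is small
    margin = 1.0 - alpha * variance_trace
    # Log barrier, clamped away from zero and capped for numerical stability
    penalty = -torch.log(torch.clamp(margin, min=eps))
    return torch.clamp(penalty, max=max_penalty)
```

Because variance_trace is non-negative, the margin never exceeds 1 and the barrier penalty stays non-negative; the clamp and cap keep it finite even when the margin collapses.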
  2. Transformer aggregation of per-layer penalties
  • File: toto/model/transformer.py
  • Threaded robustness params through TransformerLayer/Transformer construction.
  • Added Transformer.configure_causal_robustness(...) to toggle/tune all layers.
  • Aggregates layer penalties into self.latest_causal_robustness_penalty (mean across layers).
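The mean aggregation across layers might look like the following sketch. The class and attribute names mirror those described above, but this is a hypothetical stand-in, not the code in toto/model/transformer.py:

```python
import torch

class TransformerPenaltyAggregator:
    """Hypothetical stand-in for per-layer penalty aggregation."""

    def __init__(self, layers):
        # Each layer is assumed to expose latest_causal_robustness_penalty
        self.layers = layers
        self.latest_causal_robustness_penalty = None

    def aggregate_penalties(self):
        # Mean across layers keeps the penalty scale independent of depth
        penalties = [l.latest_causal_robustness_penalty for l in self.layers
                     if l.latest_causal_robustness_penalty is not None]
        if not penalties:
            self.latest_causal_robustness_penalty = None
        else:
            self.latest_causal_robustness_penalty = torch.stack(penalties).mean()
        return self.latest_causal_robustness_penalty
```

Averaging (rather than summing) means causal_robust_lambda does not need retuning when the layer count changes.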
  3. Backbone exposure/config passthrough
  • File: toto/model/backbone.py
  • Added constructor args for causal robustness options.
  • Exposed latest_causal_robustness_penalty at backbone level.
  • Added TotoBackbone.configure_causal_robustness(...) delegating to transformer.
  4. Fine-tuning loss integration + logging
  • File: toto/model/lightning_module.py
  • Added fine-tune hyperparameters:
    • causal_robust_lambda
    • causal_robust_alpha
    • causal_robust_eps
    • causal_robust_max_penalty
  • If causal_robust_lambda > 0, enables robustness regularization on model init.
  • In train/val step, adds weighted penalty to objective:
    • total_loss += causal_robust_lambda * latest_causal_robustness_penalty
  • Added metrics logging:
    • train/val_causal_robustness_penalty
    • train/val_causal_robustness_penalty_weighted
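The train/val-step integration can be sketched as a small helper. This is an illustrative function, not the actual code in toto/model/lightning_module.py:

```python
import torch

def add_causal_robustness_term(total_loss, penalty, causal_robust_lambda):
    """Hypothetical helper mirroring the described loss integration."""
    # With the default causal_robust_lambda of 0.0 (or no penalty tracked),
    # the objective is returned unchanged -- backward-compatible by design.
    if causal_robust_lambda <= 0.0 or penalty is None:
        return total_loss, torch.tensor(0.0)
    weighted = causal_robust_lambda * penalty  # logged as *_penalty_weighted
    return total_loss + weighted, weighted
```

Returning the weighted term separately makes it easy to log both the raw and weighted penalty metrics alongside the total loss.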
  5. Config plumbing + docs
  • Files:
    • toto/scripts/finetune_toto.py
    • toto/scripts/configs/finetune_config.yaml
    • README.md
  • Added config keys and README section for enabling/tuning the feature.
  6. Tests
  • File: toto/test/model/causal_robustness_test.py
  • Added tests for:
    • default-disabled behavior,
    • enabled penalty finite/non-negative behavior,
    • weighted penalty contributing to fine-tuning loss.
  • Includes module-level skip guard for environments missing required deps.
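A module-level skip guard of the kind described could look like this sketch (test and class names are illustrative, and the dependency list is assumed, not taken from the PR):

```python
import importlib.util
import unittest

# Module-level skip guard: when the upstream import chain (e.g. torch,
# lightning) is unavailable, skip every test in this file instead of erroring.
MISSING = [name for name in ("torch",)
           if importlib.util.find_spec(name) is None]

@unittest.skipIf(MISSING, f"missing dependencies: {MISSING}")
class CausalRobustnessDefaultsTest(unittest.TestCase):
    def test_penalty_off_by_default(self):
        # Hypothetical default check: lambda of 0.0 means no contribution
        self.assertEqual(0.0 * 5.0, 0.0)
```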

Backward Compatibility

  • Feature is fully optional.
  • Default configuration (causal_robust_lambda: 0.0) leaves existing behavior unchanged.

Usage

Example config:

model:
  causal_robust_lambda: 0.02
  causal_robust_alpha: 0.0001
  causal_robust_eps: 0.000001
  causal_robust_max_penalty: 20.0

Validation

  • python -m py_compile on all touched Python files
  • ⚠️ pytest toto/test/model/causal_robustness_test.py is skipped in this environment due to upstream dependency/import issues (torchvision/lightning import chain), not due to assertion failures in the new logic

Limitations / Follow-ups

  • This PR uses an attention-statistics proxy for causal robustness (not a full Jacobian/log-det barrier objective).
  • Follow-up work can add a tighter approximation and benchmark impact across FEV/BOOM settings.
