Add experimental causal-robust Toto finetuning path #88
Open
DhyeyMavani2003 wants to merge 1 commit into DataDog:main from
Summary
This PR introduces an experimental causal-robust fine-tuning path for Toto, inspired by the causal-transformer margin/barrier formulation discussed in Causal_transformers_theory.pdf. The implementation is intentionally backward-compatible and off by default.
Motivation
Recent causal-transformer theory suggests stabilizing sequence models under interventions by enforcing a positive robustness margin and penalizing near-violations with a barrier term.
In Toto, we add a practical approximation using attention-time statistics: each attention layer tracks a variance trace, from which a robustness margin and a log-barrier penalty are computed and surfaced during fine-tuning.
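The margin/barrier computation described in the change list below (margin = 1 - alpha * variance_trace, penalty = -log(clamp(margin, min=eps)), capped at max_penalty) can be sketched as follows; the function name and default values are illustrative, not the PR's actual API:

```python
import torch

def causal_robustness_penalty(variance_trace: torch.Tensor,
                              alpha: float = 0.1,
                              eps: float = 1e-6,
                              max_penalty: float = 10.0) -> torch.Tensor:
    """Barrier penalty on the robustness margin (sketch, not the PR's code).

    margin = 1 - alpha * variance_trace
    penalty = -log(clamp(margin, min=eps)), clamped to max_penalty for stability.
    """
    margin = 1.0 - alpha * variance_trace
    penalty = -torch.log(torch.clamp(margin, min=eps))
    return torch.clamp(penalty, max=max_penalty)

# As the margin shrinks toward zero (or goes negative), the barrier grows,
# but the max_penalty cap keeps the loss term bounded.
vt = torch.tensor([0.0, 5.0, 20.0])
print(causal_robustness_penalty(vt))
```

The eps clamp keeps the log argument strictly positive even when the margin is violated, and the outer clamp prevents a single layer from dominating the total loss.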
What Changed
toto/model/attention.py
- New configure_causal_robustness(enabled, alpha, eps, max_penalty) to toggle/tune the penalty
- Exposes latest_causal_robustness_penalty
- Computes margin = 1 - alpha * variance_trace and penalty = -log(clamp(margin, min=eps)), clamped to max_penalty for stability

toto/model/transformer.py
- Threads the settings through TransformerLayer/Transformer construction
- Transformer.configure_causal_robustness(...) to toggle/tune all layers
- Exposes self.latest_causal_robustness_penalty (mean across layers)

toto/model/backbone.py
- Exposes latest_causal_robustness_penalty at the backbone level
- TotoBackbone.configure_causal_robustness(...) delegating to the transformer

toto/model/lightning_module.py
- New hyperparameters: causal_robust_lambda, causal_robust_alpha, causal_robust_eps, causal_robust_max_penalty
- When causal_robust_lambda > 0, enables robustness regularization on model init
- Adds total_loss += causal_robust_lambda * latest_causal_robustness_penalty
- Logs train/val_causal_robustness_penalty and train/val_causal_robustness_penalty_weighted

Also touched: toto/scripts/finetune_toto.py, toto/scripts/configs/finetune_config.yaml, README.md, toto/test/model/causal_robustness_test.py

Backward Compatibility
The default (causal_robust_lambda: 0.0) leaves existing behavior unchanged.

Usage
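The opt-in reduces to the single weighted term added in the lightning module (total_loss += causal_robust_lambda * latest_causal_robustness_penalty). A minimal sketch with a hypothetical helper name, showing that a zero lambda is a strict no-op:

```python
import torch

def add_causal_robustness_term(total_loss: torch.Tensor,
                               penalty: torch.Tensor,
                               causal_robust_lambda: float) -> torch.Tensor:
    # Weighted regularizer from this PR: a zero lambda leaves the loss
    # untouched, which is what makes the default backward-compatible.
    if causal_robust_lambda > 0:
        total_loss = total_loss + causal_robust_lambda * penalty
    return total_loss

base, pen = torch.tensor(2.0), torch.tensor(0.5)
print(add_causal_robustness_term(base, pen, 0.0).item())  # default: unchanged
print(add_causal_robustness_term(base, pen, 0.1).item())  # regularized
```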
Example config:
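A hypothetical fragment using the causal_robust_* hyperparameter names from this PR; the YAML nesting and the values shown are assumptions, not the actual finetune_config.yaml:

```yaml
# Assumed layout; only the causal_robust_* keys come from this PR.
model:
  causal_robust_lambda: 0.1        # 0.0 (default) disables the regularizer
  causal_robust_alpha: 0.1         # margin slope: margin = 1 - alpha * variance_trace
  causal_robust_eps: 1.0e-6        # clamp floor inside the log barrier
  causal_robust_max_penalty: 10.0  # cap on the per-layer penalty
```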
Validation
- python -m py_compile on all touched Python files
- pytest toto/test/model/causal_robustness_test.py is skipped in this environment due to upstream dependency/import issues (torchvision/lightning import chain), not due to assertion failures in the new logic

Limitations / Follow-ups