FANS (Flow-based Analysis of Noise Shift) is a framework for detecting and analyzing distributional shifts in causal systems using normalizing flows. Built on the foundation of Causal Normalizing Flows, FANS extends the methodology to identify whether observed distribution changes are due to functional shifts (changes in causal mechanisms) or noise shifts (changes in noise distributions).
- Shift Detection: Automatically identifies which variables have undergone distributional shifts between environments
- Shift Type Classification: Distinguishes between function shifts and noise shifts
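The difference between the two shift types can be illustrated with a toy two-node linear structural causal model x1 → x2. This is a hand-written example, not the FANS estimator. A function shift changes how x2 responds to x1. A noise shift leaves that response unchanged and only alters the distribution of the noise term. The sketch fits the mechanism by least squares in each environment; `simulate` and `fit_mechanism` are illustrative helpers that are not part of the repository.

```python
import numpy as np


def simulate(n, weight, noise_scale, rng):
    """Toy two-node linear SCM x1 -> x2 with x2 = weight * x1 + noise_scale * noise."""
    x1 = rng.normal(size=n)
    x2 = weight * x1 + noise_scale * rng.normal(size=n)
    return x1, x2


def fit_mechanism(x1, x2):
    """Least-squares slope of x2 on x1 (no intercept) and the std of the residual noise."""
    slope = float(np.dot(x1, x2) / np.dot(x1, x1))
    residual = x2 - slope * x1
    return slope, float(residual.std())


rng = np.random.default_rng(0)
environments = {
    "env1": simulate(5000, weight=2.0, noise_scale=1.0, rng=rng),
    "function shift": simulate(5000, weight=-1.0, noise_scale=1.0, rng=rng),  # mechanism changes
    "noise shift": simulate(5000, weight=2.0, noise_scale=3.0, rng=rng),  # only the noise changes
}
for name, (x1, x2) in environments.items():
    slope, noise_std = fit_mechanism(x1, x2)
    print(f"{name:15s} slope={slope:+.2f} noise_std={noise_std:.2f}")
```

Under the function shift the fitted slope moves while the residual scale stays near 1. Under the noise shift the slope stays near its environment-1 value and only the residual scale grows.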
Create a new conda environment with Python 3.9.12:

```bash
conda create --name fans python=3.9.12 --no-default-packages
```

Activate the conda environment:

```bash
conda activate fans
```

Install PyTorch and related packages:

```bash
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
```

Install additional requirements:

```bash
pip install -r requirements.txt
```

Train a FANS model on a synthetic dataset with 10 nodes and an Erdős-Rényi (ER) graph structure:
```bash
CUDA_VISIBLE_DEVICES=0 python main.py \
  --config_file causal_nf/configs/data_small/nodes_10/ER/causal_nf_nodes_10_ER_adj_1.yaml \
  --wandb_mode offline \
  --project causal_nf
```

What this does:
- Trains a causal normalizing flow on environment 1 data
- Evaluates shift detection performance on environment 2 data
- Saves results to the `results/` directory
- Generates visualizations of detected shifts
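To train on more than one dataset, the training command can be generated in a loop. The sketch below assumes that every config file follows the naming pattern of the config paths shown in this README (`causal_nf/configs/data_small/nodes_<N>/<GRAPH>/causal_nf_nodes_<N>_<GRAPH>_adj_<i>.yaml`); whether a file exists for every combination is not confirmed here. `config_path` and `training_commands` are hypothetical helpers, not part of the repository.

```python
from itertools import product
from pathlib import PurePosixPath

CONFIG_ROOT = "causal_nf/configs/data_small"  # taken from the example commands in this README


def config_path(nodes, graph, index, root=CONFIG_ROOT):
    """Config path built from the (assumed) naming pattern of the README examples."""
    name = f"causal_nf_nodes_{nodes}_{graph}_adj_{index}.yaml"
    return str(PurePosixPath(root) / f"nodes_{nodes}" / graph / name)


def training_commands(node_counts, graphs, indices, gpu=0, wandb_mode="offline", project="causal_nf"):
    """One `python main.py` command per (nodes, graph, index) combination."""
    return [
        f"CUDA_VISIBLE_DEVICES={gpu} python main.py "
        f"--config_file {config_path(n, g, i)} "
        f"--wandb_mode {wandb_mode} --project {project}"
        for n, g, i in product(node_counts, graphs, indices)
    ]


for cmd in training_commands([10], ["ER", "SF"], range(1, 3)):
    print(cmd)
```

The printed commands can be pasted into a shell or launched one by one; printing rather than executing keeps the sketch free of side effects.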
The FANS (Flow-based Analysis of Noise Shift) method leverages trained causal normalizing flows to detect and classify distributional shifts between two environments.
- Training: Learn a causal normalizing flow on environment 1 data that maps observations X to noise variables Z following the causal graph structure
- Shift Detection: Transform environment 2 data through the learned flow and test for independence violations in the noise space
- Statistical Testing: Use distance correlation and independence tests to identify shifted variables
- Visualization: Generate comparative plots showing distributional differences
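The statistical-testing step relies on distance correlation. The sketch below is a plain NumPy implementation of the sample (V-statistic) distance correlation with a permutation p-value. It is not taken from the FANS code, and the way it is applied here is an assumption made for illustration: test whether a node's inferred environment-2 noise still depends on that node's parents. The arrays `parents`, `z_indep`, and `z_dep` are hypothetical stand-ins for flow outputs.

```python
import numpy as np


def _centered_distances(x):
    """Double-centered pairwise Euclidean distance matrix for samples of shape (n,) or (n, d)."""
    x = np.asarray(x, dtype=float)
    if x.ndim == 1:
        x = x[:, None]
    d = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    return d - d.mean(axis=0, keepdims=True) - d.mean(axis=1, keepdims=True) + d.mean()


def _dcor(a, b):
    """Distance correlation from two double-centered distance matrices."""
    dcov_xy = (a * b).mean()
    denom = np.sqrt((a * a).mean() * (b * b).mean())
    return 0.0 if denom == 0 else float(np.sqrt(max(dcov_xy, 0.0) / denom))


def distance_correlation(x, y):
    return _dcor(_centered_distances(x), _centered_distances(y))


def dcor_permutation_test(x, y, n_perm=200, seed=0):
    """Permutation test of independence between x and y; returns (dcor, p_value)."""
    rng = np.random.default_rng(seed)
    a, b = _centered_distances(x), _centered_distances(y)
    observed = _dcor(a, b)
    exceed = 0
    for _ in range(n_perm):
        p = rng.permutation(len(b))
        # Permuting the samples of y permutes the rows and columns of its centered matrix.
        if _dcor(a, b[np.ix_(p, p)]) >= observed:
            exceed += 1
    return observed, (exceed + 1) / (n_perm + 1)


rng = np.random.default_rng(1)
parents = rng.normal(size=(200, 2))  # hypothetical parent values in environment 2
z_indep = rng.normal(size=200)  # inferred noise that ignores the parents
z_dep = parents[:, 0] ** 2 + 0.1 * rng.normal(size=200)  # inferred noise that still depends on a parent
print("independent:", dcor_permutation_test(parents, z_indep))
print("dependent:  ", dcor_permutation_test(parents, z_dep))
```

Distance correlation is zero in the population only under independence, so it also detects nonlinear dependence such as the squared term above, which an ordinary Pearson correlation would miss.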
Synthetic datasets are organized by node count and graph type:
```
data/data_small/
├── nodes_10/
│   ├── ER/                    # Erdős-Rényi random graphs
│   │   ├── adj_1.npy          # Adjacency matrix
│   │   ├── data_env1_1.npy    # Environment 1 data
│   │   ├── data_env2_1.npy    # Environment 2 data
│   │   └── metadata_1.json    # Shift information
│   └── SF/                    # Scale-free graphs
├── nodes_20/
├── nodes_30/
├── nodes_40/
└── nodes_50/
```
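The file names in the layout above are enough to load one dataset with NumPy. The sketch assumes that the `.npy` files hold plain arrays and that the metadata is ordinary JSON; its schema is not documented here, so it is returned as-is. The orientation of the adjacency matrix (`adj[j, i] != 0` meaning an edge j → i) is also an assumption, and `load_synthetic` and `parents_of` are hypothetical helpers.

```python
import json
from pathlib import Path

import numpy as np


def load_synthetic(root, nodes, graph, index):
    """Load adjacency, both environments, and metadata for one synthetic dataset."""
    folder = Path(root) / f"nodes_{nodes}" / graph
    adj = np.load(folder / f"adj_{index}.npy")
    env1 = np.load(folder / f"data_env1_{index}.npy")
    env2 = np.load(folder / f"data_env2_{index}.npy")
    with open(folder / f"metadata_{index}.json") as f:
        metadata = json.load(f)  # schema not documented; returned unchanged
    return adj, env1, env2, metadata


def parents_of(adj, node):
    """Parent indices of `node`, ASSUMING adj[j, i] != 0 encodes an edge j -> i."""
    return np.flatnonzero(np.asarray(adj)[:, node])


if __name__ == "__main__":
    root = Path("data/data_small")
    if (root / "nodes_10" / "ER").exists():
        adj, env1, env2, meta = load_synthetic(root, 10, "ER", 1)
        print(adj.shape, env1.shape, env2.shape, parents_of(adj, 0))
```

If the repository stores adjacency matrices with the opposite orientation, `parents_of` should index rows (`adj[node, :]`) instead of columns.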
Other datasets:

- Morpho-MNIST: located in `data/morpho_mnist/`
- Sachs: located in `data/sachs/`
General usage:

```bash
python main.py \
  --config_file <CONFIG_PATH> \
  --wandb_mode <MODE> \
  --project <PROJECT_NAME>
```

For example, to train on a 30-node scale-free (SF) graph on GPU 1 with online W&B logging:

```bash
CUDA_VISIBLE_DEVICES=1 python main.py \
  --config_file causal_nf/configs/data_small/nodes_30/SF/causal_nf_nodes_30_SF_adj_5.yaml \
  --wandb_mode online \
  --project fans_experiments
```

Run baseline shift detection methods for comparison:
```bash
python experiments/experiment_script.py --model <MODEL_NAME> [OPTIONS]
```

Available models:

- `splitkci`: Kernel Conditional Independence Test
- `prediter`: PreDITEr method
- `iscan`: Independence-based shift detection
- `linearccp`: Linear CCP
- `gpr`: Gaussian Process Regression
Options:

| Option | Description | Default |
|---|---|---|
| `--nodes` | Node counts to process (space-separated) | `10 20 30 40 50` |
| `--gpu` | GPU device ID (`-1` for CPU) | `-1` |
| `--output_dir` | Results directory | auto-generated |
| `--config_type` | Graph type filter (`ER`, `SF`, `all`) | `all` |
| `--dataset_indices` | Dataset range (e.g., `"1-30"`) | `all` |
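The options table maps naturally onto `argparse`. The sketch below is a hypothetical mirror of the documented flags and defaults, not the actual parser in `experiment_script.py`. In particular, the range semantics of `--dataset_indices` (1-based, inclusive, `"k"` for a single dataset) and the dataset count `n_available=30` are assumptions.

```python
import argparse

MODELS = ["splitkci", "prediter", "iscan", "linearccp", "gpr"]


def parse_index_range(spec, n_available):
    """Expand "all", "k", or "start-end" (assumed 1-based, inclusive) into dataset indices."""
    if spec == "all":
        return list(range(1, n_available + 1))
    parts = [int(p) for p in spec.split("-")]
    if len(parts) == 1:
        parts = parts * 2  # a single index "k" means the range k-k
    if len(parts) != 2:
        raise ValueError(f"cannot parse dataset range {spec!r}")
    start, end = parts
    if not 1 <= start <= end <= n_available:
        raise ValueError(f"range {spec!r} is outside 1-{n_available}")
    return list(range(start, end + 1))


def build_parser():
    """Hypothetical argparse mirror of the documented options and defaults."""
    parser = argparse.ArgumentParser(description="Sketch of the baseline experiment options")
    parser.add_argument("--model", required=True, choices=MODELS)
    parser.add_argument("--nodes", type=int, nargs="+", default=[10, 20, 30, 40, 50])
    parser.add_argument("--gpu", type=int, default=-1, help="GPU device ID, -1 for CPU")
    parser.add_argument("--output_dir", default=None, help="auto-generated when omitted")
    parser.add_argument("--config_type", choices=["ER", "SF", "all"], default="all")
    parser.add_argument("--dataset_indices", default="all", help='e.g. "1-30"')
    return parser


args = build_parser().parse_args(["--model", "splitkci", "--dataset_indices", "1-5", "--gpu", "0"])
print(args.nodes, parse_index_range(args.dataset_indices, n_available=30))
```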
Run SplitKCI on all node sizes, using only the first 5 datasets:

```bash
python experiments/experiment_script.py \
  --model splitkci \
  --dataset_indices "1-5" \
  --gpu 0
```

Run ISCAN on CPU for SF graphs:
```bash
python experiments/experiment_script.py \
  --model iscan \
  --config_type SF \
  --gpu -1
```

Generate a unified results CSV:

```bash
python experiments/analysis/analysis.py
```
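The behavior of `analysis.py` is not described here. As a rough illustration of what combining per-run results into one CSV involves, the sketch below concatenates every CSV under a results directory using only the standard library. It adds a `source` column and fills columns missing from some runs with empty strings. The directory layout, the file names, and `merge_result_csvs` itself are hypothetical.

```python
import csv
from pathlib import Path


def merge_result_csvs(results_dir, out_path):
    """Concatenate every *.csv under results_dir into one file, adding a `source` column."""
    results_dir, out_path = Path(results_dir), Path(out_path)
    rows, fieldnames = [], ["source"]
    for path in sorted(results_dir.rglob("*.csv")):
        if path.resolve() == out_path.resolve():
            continue  # never read the file we are about to write
        with open(path, newline="") as f:
            for row in csv.DictReader(f):
                row["source"] = path.relative_to(results_dir).as_posix()
                fieldnames += [key for key in row if key not in fieldnames]
                rows.append(row)
    with open(out_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, restval="")  # missing columns -> ""
        writer.writeheader()
        writer.writerows(rows)
    return len(rows)
```

For example, `merge_result_csvs("results", "results_unified.csv")` would write one row per input row, tagged with the relative path of the file it came from.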