Python package for identifying heterogeneous treatment effects on two outcomes using recursive partitioning to understand relationships and trade-offs.
Requirements: Python >= 3.9, < 3.13 (Python 3.10+ recommended for NumPy 2.x support)
Divergence Tree is a machine learning algorithm designed to identify heterogeneous treatment effects when you have two outcomes of interest and want to understand how treatment effects differ between them. The algorithm segments populations into regions where treatment effects on the two outcomes diverge or converge, revealing trade-offs and relationships.
In many contexts, treatments (e.g., interventions, policies, features) can have different effects on two outcomes of interest. Understanding these relationships and trade-offs is crucial for decision-making. Common applications include:
- Firm vs Consumer outcomes: A price increase might boost firm revenue but reduce consumer satisfaction
- Long-term vs Short-term outcomes: A marketing campaign might increase immediate sales but reduce long-term brand loyalty
- Efficiency vs Quality outcomes: Process changes might improve efficiency but reduce quality
- Any two competing or complementary outcomes: Where you need to understand how treatment effects vary across both dimensions
Divergence Tree segments your population into regions where treatment effects on the two outcomes diverge or converge, helping you identify:
- Win-win regions: Where treatments benefit both outcomes
- Trade-off regions: Where treatments help one outcome but harm the other
- Lose-lose regions: Where treatments harm both outcomes
The package provides two methods:

- `DivergenceTree`: Directly optimizes a split objective function that measures divergence between treatment effects on the two outcomes. It grows a tree by recursively partitioning the feature space, then prunes splits that don't improve the objective.
- `TwoStepDivergenceTree`: A two-step approach that first estimates treatment effects using causal forests, then trains a classification tree to predict region types based on those estimates.
Both methods categorize observations into 4 region types based on the signs of treatment effects:
- Region 1: τ₁ > 0, τ₂ > 0 (both positive - win-win)
- Region 2: τ₁ > 0, τ₂ ≤ 0 (first outcome positive, second negative - trade-off favoring first)
- Region 3: τ₁ ≤ 0, τ₂ > 0 (first outcome negative, second positive - trade-off favoring second)
- Region 4: τ₁ ≤ 0, τ₂ ≤ 0 (both negative - lose-lose)
Where τ₁ is the treatment effect on the first outcome and τ₂ is the treatment effect on the second outcome. In the firm/consumer example, these would be firm effects (τF) and consumer effects (τC).
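The sign-based mapping above can be sketched as a small helper (illustrative only; `region_of` is not part of the package API, which returns these labels via `predict_region_type`):

```python
def region_of(tau1, tau2):
    """Map a pair of treatment effects to a region type (1-4).

    Region 1: tau1 > 0, tau2 > 0  (win-win)
    Region 2: tau1 > 0, tau2 <= 0 (trade-off favoring the first outcome)
    Region 3: tau1 <= 0, tau2 > 0 (trade-off favoring the second outcome)
    Region 4: tau1 <= 0, tau2 <= 0 (lose-lose)
    """
    if tau1 > 0 and tau2 > 0:
        return 1
    if tau1 > 0:
        return 2
    if tau2 > 0:
        return 3
    return 4
```

Note that zero effects fall on the non-positive side of each boundary, matching the ≤ conditions above.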
- Download the package:

```bash
git clone https://github.com/ebzgr/divergence-tree
cd divergence-tree
```

- Create a virtual environment:

Windows:

```bash
python -m venv .venv
.venv\Scripts\activate
```

Linux/Mac:

```bash
python -m venv .venv
source .venv/bin/activate
```

- Install the package (dependencies are installed automatically):

```bash
pip install -e .
```

If you encounter installation issues, try:

```bash
# Option 1: Upgrade pip first
pip install --upgrade pip setuptools wheel

# Option 2: Install with exact versions (if flexible versions fail)
pip install -r requirements-exact.txt
pip install -e .

# Option 3: Install dependencies separately
pip install -r requirements.txt
pip install -e .
```

Dependencies: numpy, pandas, matplotlib, optuna, scikit-learn, econml, scipy, joblib, lightgbm, shap, sparse, statsmodels
Common issues:

- "No module named 'econml'" or missing dependencies:
  - Ensure all dependencies are installed: `pip install -r requirements.txt`
  - Some systems may need to install build tools (e.g., `build-essential` on Linux)
- NumPy compatibility errors:
  - NumPy 2.x requires Python 3.10+
  - If using Python < 3.10, you may need numpy < 2.0: `pip install "numpy<2.0"`
- LightGBM installation fails:
  - Windows: may need Visual C++ Build Tools
  - Linux/Mac: may need `cmake` and `g++`
- Version conflicts:
  - Use `requirements-exact.txt` for exact version matching
  - Or create a fresh virtual environment: `python -m venv .venv --clear`
Before using the algorithms, prepare your data:
- `X`: Feature matrix of shape `(n_samples, n_features)` - characteristics of each observation (e.g., user demographics, product features, time periods)
- `T`: Treatment indicator of shape `(n_samples,)` with values in {0, 1} - whether each observation received the treatment
- `YF`: First outcome of shape `(n_samples,)` - binary or continuous, may contain NaN (e.g., firm revenue, short-term sales, efficiency metrics)
- `YC`: Second outcome of shape `(n_samples,)` - binary or continuous, may contain NaN (e.g., consumer satisfaction, long-term loyalty, quality metrics)
Note: YF and YC can represent any two outcomes of interest. The naming (F/C) is a convention from the firm/consumer example, but you can use any two outcomes.
Outcome types (binary vs continuous) are automatically detected. NaN values are handled as missing data.
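As a sketch of the expected shapes, here is a synthetic dataset with the right layout (the data-generating process below is invented purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000

# Feature matrix of shape (n_samples, n_features)
X = rng.normal(size=(n, 5))

# Binary treatment indicator with values in {0, 1}
T = rng.integers(0, 2, size=n)

# Two continuous outcomes; the treatment helps YF but hurts YC when X[:, 1] > 0
YF = X[:, 0] + 0.5 * T * (X[:, 1] > 0) + rng.normal(size=n)
YC = X[:, 2] - 0.5 * T * (X[:, 1] > 0) + rng.normal(size=n)

# NaN values are allowed in the outcomes and treated as missing data
YC[rng.random(n) < 0.05] = np.nan
```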
```python
from divtree.tree import DivergenceTree
from divtree.tune import tune_with_optuna

# Optional: Tune hyperparameters using cross-validation
best_params, best_loss = tune_with_optuna(
    X, T, YF, YC,
    fixed={"lambda_": 1, "random_state": 0},
    search_space={
        "max_partitions": {"low": 4, "high": 15},
        "min_improvement_ratio": {"low": 0.001, "high": 0.05, "log": True},
    },
    n_trials=20,
    n_splits=5,
)

# Train the tree
tree = DivergenceTree(**best_params)
tree.fit(X, T, YF, YC)

# Predict region types (returns array of 1-4)
region_types = tree.predict_region_type(X)

# Get treatment effects for each leaf
leaf_effects = tree.leaf_effects()
```

Key parameters:

- `max_partitions`: Maximum leaves before pruning (default: 8)
- `min_improvement_ratio`: Minimum improvement required to keep a split (default: 0.01)
- `lambda_`: Co-movement weight in the split objective (default: 1.0)
- `regions_of_interest`: List of region numbers (1-4) to focus on in the objective function (default: None, meaning all regions [1, 2, 3, 4] are of interest)

- Region 1: τF > 0 and τC > 0 (both positive - win-win)
- Region 2: τF > 0 and τC ≤ 0 (firm+, customer-)
- Region 3: τF ≤ 0 and τC > 0 (firm-, customer+)
- Region 4: τF ≤ 0 and τC ≤ 0 (both negative - lose-lose)
Objective Function:
The split selection objective combines heterogeneity and co-movement:
- Heterogeneity: H = zF² + zC² (always included for all regions)
- Co-movement: d = zF * zC, with φ(d) = |d|
- Region weighting: w_region = 1.0 if the region is in `regions_of_interest`, else 0.0
- Objective: g = H + λ * w_region * φ(d)
where zF and zC are normalized deviations from baseline effects. The co-movement term is only weighted when the region belongs to regions_of_interest, allowing you to focus the optimization on specific regions while maintaining heterogeneity detection for all regions.
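A toy calculation of this objective for one candidate split (the `zF` and `zC` values below are hypothetical; the actual normalization of deviations is handled inside the package):

```python
# Normalized deviations from the baseline effects (hypothetical values)
zF, zC = 1.2, -0.5
lambda_ = 1.0

H = zF**2 + zC**2   # heterogeneity term: included for all regions
d = zF * zC         # co-movement
phi = abs(d)        # phi(d) = |d|

# zF > 0 and zC <= 0 corresponds to region 2; the co-movement term gets
# weight 1.0 only if region 2 is in regions_of_interest, else 0.0
w_region = 1.0

g = H + lambda_ * w_region * phi
```

With region 2 excluded from `regions_of_interest`, `w_region` would be 0.0 and the objective would reduce to the heterogeneity term alone.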
```python
from twostepdivtree.tree import TwoStepDivergenceTree

# Initialize the two-step model
tree = TwoStepDivergenceTree(
    causal_forest_params={"n_jobs": -1, "random_state": 0},
    causal_forest_tune_params={"params": "auto"},
    classification_tree_params={"random_state": 0},
)

# Train with optional classification tree tuning
tree.fit(
    X, T, YF, YC,
    auto_tune_classification_tree=True,
    classification_tree_search_space={
        "max_depth": {"low": 2, "high": 15},
        "min_samples_split": {"low": 2, "high": 20},
    },
    classification_tree_tune_n_trials=30,
)

# Predict region types
region_types = tree.predict_region_type(X)

# Predict treatment effects (optional)
tauF, tauC = tree.predict_treatment_effects(X)
```

Choosing between the two methods:

- `DivergenceTree`: Direct optimization of joint treatment effects. Better when you want a single unified model and have sufficient data for tuning.
- `TwoStepDivergenceTree`: Separates effect estimation from classification. Better when you need interpretable treatment effect estimates or want to leverage causal forest's robustness.
- `examples/basic.py`: Complete workflow with data generation, hyperparameter tuning, and visualization.
- `simulations/comparison/simulate.py`: Compare both methods on the same dataset. Four independent steps:
- Generate and save data
- Run DivergenceTree, save results
- Run TwoStepDivergenceTree, save results
- Compare results and visualize
`DivergenceTree`:

- `fit(X, T, YF, YC)`: Train the tree on data
- `predict_region_type(X)`: Predict region types (1-4) for new observations
- `leaf_effects()`: Get treatment effects for each leaf node

`TwoStepDivergenceTree`:

- `fit(X, T, YF, YC, ...)`: Train the two-step model
- `predict_region_type(X)`: Predict region types (1-4) for new observations
- `predict_treatment_effects(X)`: Predict τF and τC for new observations
- `leaf_effects()`: Get leaf effect summary