Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 81 additions & 36 deletions dowhy/causal_refuters/add_unobserved_common_cause.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pandas as pd
import scipy.stats
import statsmodels.api as sm
from joblib import Parallel, delayed
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from tqdm.auto import tqdm
Expand Down Expand Up @@ -72,6 +73,8 @@ def __init__(self, *args, **kwargs):
:param alpha_s_estimator_param_list: list of dictionaries with parameters for finding alpha_s. (relevant only for non-parametric-partial-R2 simulation method)
:param g_s_estimator_list: list of estimator objects for finding g_s. These objects should have fit() and predict() functions implemented. (relevant only for non-parametric-partial-R2 simulation method)
:param g_s_estimator_param_list: list of dictionaries with parameters for tuning respective estimators in "g_s_estimator_list". The order of the dictionaries in the list should be consistent with the estimator objects order in "g_s_estimator_list". (relevant only for non-parametric-partial-R2 simulation method)
:param n_jobs: The maximum number of concurrently running jobs. If -1 all CPUs are used. If 1 is given, no parallel computing code is used at all (this is the default). (relevant only for direct-simulation method)
:param verbose: The verbosity level: if non zero, progress messages are printed. Above 50, the output is sent to stdout. The frequency of the messages increases with the verbosity level. If it more than 10, all iterations are reported. The default is 0. (relevant only for direct-simulation method)
"""
super().__init__(*args, **kwargs)
self.simulation_method = kwargs["simulation_method"] if "simulation_method" in kwargs else "direct-simulation"
Expand Down Expand Up @@ -182,6 +185,8 @@ def refute_estimate(self, show_progress_bar=False):
self.frac_strength_outcome,
self.plotmethod,
show_progress_bar,
self._n_jobs,
self._verbose,
)
refute.add_refuter(self)
return refute
Expand Down Expand Up @@ -780,6 +785,44 @@ def sensitivity_e_value(
return analyzer


def _simulate_confounders_effect_once(
    data: pd.DataFrame,
    orig_data: pd.DataFrame,
    target_estimand: IdentifiedEstimand,
    estimate: CausalEstimate,
    treatment_name: str,
    outcome_name: str,
    confounders_effect_on_treatment: str,
    confounders_effect_on_outcome: str,
    kappa_t_value: float,
    kappa_y_value: float,
) -> float:
    """Run a single sensitivity simulation for one (kappa_t, kappa_y) pair.

    Injects the simulated unobserved confounder's effect into the data,
    refits a fresh estimator of the same kind on the modified data, and
    re-estimates the causal effect.

    :param data: working dataset the confounder effect is applied to.
    :param orig_data: untouched copy of the original dataset.
    :param target_estimand: identified estimand used to build the new estimator.
    :param estimate: original estimate; supplies the estimator template,
        control/treatment values, target units, and effect modifiers.
    :param kappa_t_value: strength of the confounder's effect on treatment.
    :param kappa_y_value: strength of the confounder's effect on outcome.
    :returns: the point estimate of the effect on the confounded data.
    """
    # Apply the unobserved confounder's effect for this specific kappa pair.
    confounded_data = _include_confounders_effect(
        data,
        orig_data,
        confounders_effect_on_treatment,
        treatment_name,
        kappa_t_value,
        confounders_effect_on_outcome,
        outcome_name,
        kappa_y_value,
    )

    # Build and fit a brand-new estimator of the same type as the original,
    # forwarding any stored fit parameters when the estimator carries them.
    refit_estimator = estimate.estimator.get_new_estimator_object(target_estimand)
    fit_kwargs = refit_estimator._fit_params if hasattr(refit_estimator, "_fit_params") else {}
    refit_estimator.fit(
        confounded_data,
        effect_modifier_names=estimate.estimator._effect_modifier_names,
        **fit_kwargs,
    )

    # Re-estimate the effect under the same control/treatment configuration.
    simulated_effect = refit_estimator.estimate_effect(
        confounded_data,
        control_value=estimate.control_value,
        treatment_value=estimate.treatment_value,
        target_units=estimate.estimator._target_units,
    )
    return simulated_effect.value


def sensitivity_simulation(
data: pd.DataFrame,
target_estimand: IdentifiedEstimand,
Expand All @@ -794,6 +837,8 @@ def sensitivity_simulation(
frac_strength_outcome: float = 1.0,
plotmethod: Optional[str] = None,
show_progress_bar=False,
n_jobs: int = 1,
verbose: int = 0,
**_,
) -> CausalRefutation:
"""
Expand Down Expand Up @@ -867,45 +912,45 @@ def sensitivity_simulation(
# Get a 2D matrix of values
# x,y = np.meshgrid(self.kappa_t, self.kappa_y) # x,y are both MxN

results_matrix = np.random.rand(len(kappa_t), len(kappa_y)) # Matrix to hold all the results of NxM
results_matrix = np.zeros((len(kappa_t), len(kappa_y))) # Matrix to hold all the results of NxM
orig_data = copy.deepcopy(data)

for i in tqdm(
range(len(kappa_t)),
colour=CausalRefuter.PROGRESS_BAR_COLOR,
disable=not show_progress_bar,
desc="Refuting Estimates: ",
):
for j in range(len(kappa_y)):
new_data = _include_confounders_effect(
data,
orig_data,
confounders_effect_on_treatment,
treatment_name,
kappa_t[i],
confounders_effect_on_outcome,
outcome_name,
kappa_y[j],
)
new_estimator = estimate.estimator.get_new_estimator_object(target_estimand)
new_estimator.fit(
new_data,
effect_modifier_names=estimate.estimator._effect_modifier_names,
**new_estimator._fit_params if hasattr(new_estimator, "_fit_params") else {},
)
new_effect = new_estimator.estimate_effect(
new_data,
control_value=estimate.control_value,
treatment_value=estimate.treatment_value,
target_units=estimate.estimator._target_units,
)
refute = CausalRefutation(
estimate.value,
new_effect.value,
refutation_type="Refute: Add an Unobserved Common Cause",
)
results_matrix[i][j] = refute.new_effect # Populate the results
# Create list of parameter combinations for parallel execution
param_combinations = [
(i, j, kappa_t[i], kappa_y[j]) for i in range(len(kappa_t)) for j in range(len(kappa_y))
]

# Run simulations in parallel
results = Parallel(n_jobs=n_jobs, verbose=verbose)(
delayed(_simulate_confounders_effect_once)(
data,
orig_data,
target_estimand,
estimate,
treatment_name,
outcome_name,
confounders_effect_on_treatment,
confounders_effect_on_outcome,
kappa_t_val,
kappa_y_val,
)
for i, j, kappa_t_val, kappa_y_val in tqdm(
param_combinations,
colour=CausalRefuter.PROGRESS_BAR_COLOR,
disable=not show_progress_bar,
desc="Refuting Estimates: ",
)
)

# Populate the results matrix
for (i, j, _, _), result in zip(param_combinations, results):
results_matrix[i][j] = result

refute = CausalRefutation(
estimate.value,
results[-1], # Use last result as representative
refutation_type="Refute: Add an Unobserved Common Cause",
)
refute.new_effect_array = results_matrix
refute.new_effect = (np.min(results_matrix), np.max(results_matrix))
# Store the values into the refute object
Expand Down
Loading
Loading