Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions dowhy/causal_estimators/distance_matching_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(
# that will be passed to sklearn nearestneighbors
self.distance_metric_params = {}
for param_name in self.Valid_Dist_Metric_Params:
# Pull values from kwargs directly since the base class discards them.
param_val = kwargs.get(param_name, None)
if param_val is not None:
self.distance_metric_params[param_name] = param_val
Expand Down Expand Up @@ -200,7 +201,7 @@ def estimate_effect(
n_neighbors=self.num_matches_per_unit,
metric=self.distance_metric,
algorithm="ball_tree",
**self.distance_metric_params,
metric_params=self.distance_metric_params if self.distance_metric_params else None,
).fit(control[self._observed_common_causes.columns].values)
distances, indices = control_neighbors.kneighbors(treated[self._observed_common_causes.columns].values)
self.logger.debug("distances:")
Expand Down Expand Up @@ -238,7 +239,7 @@ def estimate_effect(
n_neighbors=self.num_matches_per_unit,
metric=self.distance_metric,
algorithm="ball_tree",
**self.distance_metric_params,
metric_params=self.distance_metric_params if self.distance_metric_params else None,
).fit(control[self._observed_common_causes.columns].values)
distances, indices = control_neighbors.kneighbors(
treated[self._observed_common_causes.columns].values
Expand Down Expand Up @@ -267,7 +268,7 @@ def estimate_effect(
n_neighbors=self.num_matches_per_unit,
metric=self.distance_metric,
algorithm="ball_tree",
**self.distance_metric_params,
metric_params=self.distance_metric_params if self.distance_metric_params else None,
).fit(treated[self._observed_common_causes.columns].values)
distances, indices = treated_neighbors.kneighbors(control[self._observed_common_causes.columns].values)

Expand Down
102 changes: 102 additions & 0 deletions tests/causal_estimators/test_distance_matching_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import numpy as np
import pytest

from dowhy.causal_estimators.distance_matching_estimator import DistanceMatchingEstimator

from .base import SimpleEstimator


@pytest.mark.usefixtures("fixed_seed")
class TestDistanceMatchingEstimator:
@pytest.mark.parametrize(
[
"error_tolerance",
"Estimator",
"num_common_causes",
"num_instruments",
"num_effect_modifiers",
"num_treatments",
"treatment_is_binary",
"outcome_is_binary",
"identifier_method",
],
[
(
0.3,
DistanceMatchingEstimator,
[1, 2],
[0],
[0],
[1],
[True],
[False],
"backdoor",
),
],
)
def test_average_treatment_effect(
self,
error_tolerance,
Estimator,
num_common_causes,
num_instruments,
num_effect_modifiers,
num_treatments,
treatment_is_binary,
outcome_is_binary,
identifier_method,
):
estimator_tester = SimpleEstimator(error_tolerance, Estimator, identifier_method=identifier_method)
estimator_tester.average_treatment_effect_testsuite(
num_common_causes=num_common_causes,
num_instruments=num_instruments,
num_effect_modifiers=num_effect_modifiers,
num_treatments=num_treatments,
treatment_is_binary=treatment_is_binary,
outcome_is_binary=outcome_is_binary,
)

def test_distance_metric_params_passed_to_estimator(self):
"""Regression test for https://github.com/py-why/dowhy/issues/1390.

distance_metric_params such as V (for Mahalanobis) must be forwarded
from method_params to NearestNeighbors, not silently dropped.
"""
import dowhy.datasets
from dowhy import EstimandType, identify_effect_auto
from dowhy.graph import build_graph_from_str

data = dowhy.datasets.linear_dataset(
beta=10,
num_common_causes=2,
num_instruments=0,
num_effect_modifiers=0,
num_treatments=1,
num_samples=500,
treatment_is_binary=True,
)
graph = build_graph_from_str(data["gml_graph"])
observed_nodes = list(data["df"].columns)
identified_estimand = identify_effect_auto(
graph,
data["treatment_name"],
data["outcome_name"],
observed_nodes,
EstimandType.NONPARAMETRIC_ATE,
)
common_causes = data["df"][data["common_causes_names"]].values
V = np.cov(common_causes.T)

estimator = DistanceMatchingEstimator(
identified_estimand=identified_estimand,
distance_metric="mahalanobis",
V=V,
)
assert estimator.distance_metric_params == {"V": V}, (
"distance_metric_params should capture V from kwargs"
)

# Also verify that fit + estimate_effect works end-to-end without error.
estimator.fit(data["df"])
estimate = estimator.estimate_effect(data["df"], target_units="att")
assert estimate.value is not None
Loading