From 9b413481b6dd4c735939f2c74dfc127b2d4e8a15 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 19 Mar 2026 23:46:13 +0000 Subject: [PATCH] fix: forward distance_metric_params from kwargs to NearestNeighbors The DistanceMatchingEstimator.__init__ accepted **kwargs for extra distance metric params (V, VI, p, w) but never stored them, because getattr(self, param_name) returned None for params that are only in kwargs. The base CausalEstimator.__init__ uses **_ so they were silently dropped. Additionally, NearestNeighbors requires these params via its metric_params= keyword argument, not as top-level kwargs. Fix both issues: - Build distance_metric_params dict directly from kwargs - Pass it as metric_params=... to NearestNeighbors Add regression test for Mahalanobis distance with V matrix. Closes #1390 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../distance_matching_estimator.py | 19 ++-- .../test_distance_matching_estimator.py | 102 ++++++++++++++++++ 2 files changed, 112 insertions(+), 9 deletions(-) create mode 100644 tests/causal_estimators/test_distance_matching_estimator.py diff --git a/dowhy/causal_estimators/distance_matching_estimator.py b/dowhy/causal_estimators/distance_matching_estimator.py index 2a6d1ebfc7..6fe7c2903b 100644 --- a/dowhy/causal_estimators/distance_matching_estimator.py +++ b/dowhy/causal_estimators/distance_matching_estimator.py @@ -91,12 +91,13 @@ def __init__( self.distance_metric = distance_metric # Dictionary of any user-provided params for the distance metric - # that will be passed to sklearn nearestneighbors - self.distance_metric_params = {} - for param_name in self.Valid_Dist_Metric_Params: - param_val = getattr(self, param_name, None) - if param_val is not None: - self.distance_metric_params[param_name] = param_val + # that will be passed to sklearn nearestneighbors. + # Pull values from kwargs directly since the base class discards them. + self.distance_metric_params = { + param_name: kwargs[param_name] + for param_name in self.Valid_Dist_Metric_Params + if param_name in kwargs + } self.logger.info("INFO: Using Distance Matching Estimator") @@ -200,7 +201,7 @@ def estimate_effect( n_neighbors=self.num_matches_per_unit, metric=self.distance_metric, algorithm="ball_tree", - **self.distance_metric_params, + metric_params=self.distance_metric_params if self.distance_metric_params else None, ).fit(control[self._observed_common_causes.columns].values) distances, indices = control_neighbors.kneighbors(treated[self._observed_common_causes.columns].values) self.logger.debug("distances:") @@ -238,7 +239,7 @@ def estimate_effect( n_neighbors=self.num_matches_per_unit, metric=self.distance_metric, algorithm="ball_tree", - **self.distance_metric_params, + metric_params=self.distance_metric_params if self.distance_metric_params else None, ).fit(control[self._observed_common_causes.columns].values) distances, indices = control_neighbors.kneighbors( treated[self._observed_common_causes.columns].values @@ -267,7 +268,7 @@ def estimate_effect( n_neighbors=self.num_matches_per_unit, metric=self.distance_metric, algorithm="ball_tree", - **self.distance_metric_params, + metric_params=self.distance_metric_params if self.distance_metric_params else None, ).fit(treated[self._observed_common_causes.columns].values) distances, indices = treated_neighbors.kneighbors(control[self._observed_common_causes.columns].values) diff --git a/tests/causal_estimators/test_distance_matching_estimator.py b/tests/causal_estimators/test_distance_matching_estimator.py new file mode 100644 index 0000000000..4a12294689 --- /dev/null +++ b/tests/causal_estimators/test_distance_matching_estimator.py @@ -0,0 +1,102 @@ +import numpy as np +import pytest + +from dowhy.causal_estimators.distance_matching_estimator import DistanceMatchingEstimator + +from .base import SimpleEstimator + + +@pytest.mark.usefixtures("fixed_seed") +class TestDistanceMatchingEstimator: + @pytest.mark.parametrize( + [ + "error_tolerance", + "Estimator", + "num_common_causes", + "num_instruments", + "num_effect_modifiers", + "num_treatments", + "treatment_is_binary", + "outcome_is_binary", + "identifier_method", + ], + [ + ( + 0.3, + DistanceMatchingEstimator, + [1, 2], + [0], + [0], + [1], + [True], + [False], + "backdoor", + ), + ], + ) + def test_average_treatment_effect( + self, + error_tolerance, + Estimator, + num_common_causes, + num_instruments, + num_effect_modifiers, + num_treatments, + treatment_is_binary, + outcome_is_binary, + identifier_method, + ): + estimator_tester = SimpleEstimator(error_tolerance, Estimator, identifier_method=identifier_method) + estimator_tester.average_treatment_effect_testsuite( + num_common_causes=num_common_causes, + num_instruments=num_instruments, + num_effect_modifiers=num_effect_modifiers, + num_treatments=num_treatments, + treatment_is_binary=treatment_is_binary, + outcome_is_binary=outcome_is_binary, + ) + + def test_distance_metric_params_passed_to_estimator(self): + """Regression test for https://github.com/py-why/dowhy/issues/1390. + + distance_metric_params such as V (for Mahalanobis) must be forwarded + from method_params to NearestNeighbors, not silently dropped. + """ + import dowhy.datasets + from dowhy import EstimandType, identify_effect_auto + from dowhy.graph import build_graph_from_str + + data = dowhy.datasets.linear_dataset( + beta=10, + num_common_causes=2, + num_instruments=0, + num_effect_modifiers=0, + num_treatments=1, + num_samples=500, + treatment_is_binary=True, + ) + graph = build_graph_from_str(data["gml_graph"]) + observed_nodes = list(data["df"].columns) + identified_estimand = identify_effect_auto( + graph, + data["treatment_name"], + data["outcome_name"], + observed_nodes, + EstimandType.NONPARAMETRIC_ATE, + ) + common_causes = data["df"][data["common_causes_names"]].values + V = np.cov(common_causes.T) + + estimator = DistanceMatchingEstimator( + identified_estimand=identified_estimand, + distance_metric="mahalanobis", + V=V, + ) + assert estimator.distance_metric_params == {"V": V}, ( + "distance_metric_params should capture V from kwargs" + ) + + # Also verify that fit + estimate_effect works end-to-end without error. + estimator.fit(data["df"]) + estimate = estimator.estimate_effect(data["df"], target_units="att") + assert estimate.value is not None