Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions dowhy/causal_identifier/auto_identifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,14 +942,15 @@ def identify_generalized_adjustment_set(


def identify_mediation(graph: nx.DiGraph, action_nodes: List[str], outcome_nodes: List[str]):
"""Find a valid mediator if it exists.
"""Find all valid mediators between action and outcome nodes.

Currently only supports a single variable mediator set.
Returns a list of all variables that lie on a directed path from
action_nodes to outcome_nodes (each individually blocks at least one
such path when conditioned on).
"""
mediation_var = None
mediation_vars = []
mediation_paths = get_all_directed_paths(graph, action_nodes, outcome_nodes)
eligible_variables = get_descendants(graph, action_nodes) - set(outcome_nodes)
# For simplicity, assuming a one-variable mediation set
for candidate_var in eligible_variables:
is_valid_mediation = check_valid_mediation_set(
graph,
Expand All @@ -960,9 +961,9 @@ def identify_mediation(graph: nx.DiGraph, action_nodes: List[str], outcome_nodes
)
logger.debug("Candidate mediation set: {0}, on_mediating_path: {1}".format(candidate_var, is_valid_mediation))
if is_valid_mediation:
mediation_var = candidate_var
break
return parse_state(mediation_var)
mediation_vars.append(candidate_var)
# Sort for deterministic output — eligible_variables is a set.
return sorted(mediation_vars)


def identify_mediation_first_stage_confounders(
Expand Down
21 changes: 21 additions & 0 deletions tests/causal_identifiers/example_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,27 @@
maximal_adjustment_sets=[{"W"}],
direct_maximal_adjustment_sets=[{"W", "M"}],
),
# Treatment node must be named "X" and outcome "Y" to match the convention
# hardcoded in IdentificationTestGraphSolution (base.py) and all parametrized
# backdoor tests. The graph encodes D->Y, D->M1->Y, D->M2->Y from issue #1334
# — here renamed to X->Y, X->M1->Y, X->M2->Y for fixture compatibility.
"parallel-mediators": dict(
graph_str="""graph[directed 1 node[id "X" label "X"]
node[id "Y" label "Y"]
node[id "M1" label "M1"]
node[id "M2" label "M2"]
edge[source "X" target "Y"]
edge[source "X" target "M1"]
edge[source "X" target "M2"]
edge[source "M1" target "Y"]
edge[source "M2" target "Y"]]
""",
observed_variables=["X", "Y", "M1", "M2"],
biased_sets=[{"M1"}, {"M2"}, {"M1", "M2"}],
minimal_adjustment_sets=[set()],
maximal_adjustment_sets=[set()],
direct_maximal_adjustment_sets=[{"M1", "M2"}],
),
"mediator-with-conf": dict(
graph_str="""graph[directed 1 node[id "X" label "X"]
node[id "Y" label "Y"]
Expand Down
42 changes: 42 additions & 0 deletions tests/test_causal_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import dowhy
import dowhy.datasets
from dowhy import CausalModel
from dowhy.causal_identifier.auto_identifier import identify_mediation
from dowhy.graph import *
from dowhy.utils.graph_operations import daggity_to_dot

Expand Down Expand Up @@ -125,3 +126,44 @@ def test_has_path(self):
assert has_directed_path(self.nx_graph, ["X0", "X1", "X2"], ["y", "v0"])
assert not has_directed_path(self.nx_graph, [], ["y"])
assert not has_directed_path(self.nx_graph, ["X0", "X1", "X2"], ["y", "v0", "Z0"])


def test_identify_mediation_single_mediator():
    """With X -> M -> Y plus a direct X -> Y edge, M is the only mediator returned."""
    dag = nx.DiGraph()
    dag.add_edges_from([("X", "M"), ("M", "Y"), ("X", "Y")])
    found = identify_mediation(dag, ["X"], ["Y"])
    assert found == ["M"]


def test_identify_mediation_parallel_mediators():
    """identify_mediation returns every valid mediator, not just the first.

    The graph has two parallel mediated paths (D -> M1 -> Y and D -> M2 -> Y)
    alongside the direct edge D -> Y, so both M1 and M2 must be reported.

    Regression test for https://github.com/py-why/dowhy/issues/1334
    """
    graph = nx.DiGraph([("D", "M1"), ("D", "M2"), ("D", "Y"), ("M1", "Y"), ("M2", "Y")])
    mediators = identify_mediation(graph, ["D"], ["Y"])
    # Compare against the exact list rather than a set: identify_mediation
    # sorts its result for determinism, and a set comparison would silently
    # accept duplicated entries or an unstable ordering.
    assert mediators == ["M1", "M2"]


def test_identify_mediation_no_mediator():
    """A graph containing only the direct edge X -> Y yields an empty mediator list."""
    dag = nx.DiGraph()
    dag.add_edge("X", "Y")
    assert identify_mediation(dag, ["X"], ["Y"]) == []


def test_nie_with_parallel_mediators():
    """End-to-end check that the NIE estimand reports both parallel mediators.

    Builds a CausalModel over D -> Y, D -> M1 -> Y and D -> M2 -> Y with an
    empty data frame (only the column names are supplied) and identifies the
    "nonparametric-nie" estimand.

    Regression test for https://github.com/py-why/dowhy/issues/1334
    """
    node_names = ["D", "M1", "M2", "Y"]
    dag = nx.DiGraph()
    dag.add_edges_from([("D", "M1"), ("D", "M2"), ("D", "Y"), ("M1", "Y"), ("M2", "Y")])
    empty_frame = pd.DataFrame(columns=node_names)
    causal_model = CausalModel(data=empty_frame, graph=dag, treatment="D", outcome="Y")
    nie_estimand = causal_model.identify_effect(estimand_type="nonparametric-nie")
    reported = set(nie_estimand.get_mediator_variables())
    assert reported == {"M1", "M2"}
Loading