Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions causalpy/experiments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import pandas as pd
from sklearn.base import RegressorMixin

from causalpy.plot_utils import ResponseType
from causalpy.pymc_models import PyMCModel
from causalpy.reporting import EffectSummary
from causalpy.skl_models import create_causalpy_compatible_class
Expand Down Expand Up @@ -156,6 +157,7 @@ def effect_summary(
treated_unit: str | None = None,
period: Literal["intervention", "post", "comparison"] | None = None,
prefix: str = "Post-period",
response_type: ResponseType = "expectation",
**kwargs: Any,
) -> EffectSummary:
"""
Expand Down Expand Up @@ -192,6 +194,19 @@ def effect_summary(
prefix : str, optional
Prefix for prose generation (e.g., "During intervention", "Post-intervention").
Defaults to "Post-period".
response_type : {"expectation", "prediction"}, default="expectation"
Response type to compute effect sizes (ITS/SC only, ignored for DiD/RD/RKink):

- ``"expectation"``: Effect size HDI based on model expectation (μ).
Excludes observation noise, focusing on the systematic causal effect.
- ``"prediction"``: Effect size HDI based on posterior predictive (ŷ).
Includes observation noise, showing full predictive uncertainty.

Note: This parameter only affects experiments where the causal effect is
calculated as the difference between observed and predicted values
(ITS, Synthetic Control). For experiments where the effect is a model
coefficient (DiD, RD, RKink), the HDI is always computed from the
posterior of the coefficient and this parameter is ignored.

Returns
-------
Expand Down
63 changes: 54 additions & 9 deletions causalpy/experiments/diff_in_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@
DataException,
FormulaException,
)
from causalpy.plot_utils import plot_xY
from causalpy.plot_utils import (
ResponseType,
_log_response_type_info_once,
add_hdi_annotation,
plot_xY,
)
from causalpy.pymc_models import LinearRegression, PyMCModel
from causalpy.reporting import (
EffectSummary,
Expand Down Expand Up @@ -335,21 +340,54 @@ def _causal_impact_summary_stat(self, round_to: int | None = None) -> str:
return f"Causal impact = {convert_to_string(self.causal_impact, round_to=round_to)}"

def _bayesian_plot(
self, round_to: int | None = None, **kwargs: dict
self,
round_to: int | None = None,
response_type: ResponseType = "expectation",
show_hdi_annotation: bool = False,
**kwargs: dict,
) -> tuple[plt.Figure, plt.Axes]:
"""
Plot the results

:param round_to:
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
Parameters
----------
round_to : int, optional
Number of decimals used to round results. Defaults to 2.
Use None to return raw numbers.
response_type : {"expectation", "prediction"}, default="expectation"
The response type to display in the HDI band:

- ``"expectation"``: HDI of the model expectation (μ). This shows
uncertainty from model parameters only, excluding observation noise.
Results in narrower intervals that represent the uncertainty in
the expected value of the outcome.
- ``"prediction"``: HDI of the posterior predictive (ŷ). This includes
observation noise (σ) in addition to parameter uncertainty, resulting
in wider intervals that represent the full predictive uncertainty
for new observations.
show_hdi_annotation : bool, default=False
Whether to display a text annotation at the bottom of the figure
explaining what the HDI represents. Set to False to hide the annotation.
**kwargs : dict
Additional keyword arguments.

Returns
-------
tuple[plt.Figure, plt.Axes]
The matplotlib figure and axes.
"""
# Log HDI type info once per session
_log_response_type_info_once()

# Select the variable name based on response_type
var_name = "mu" if response_type == "expectation" else "y_hat"

def _plot_causal_impact_arrow(results, ax):
"""
draw a vertical arrow between `y_pred_counterfactual` and
`y_pred_counterfactual`
"""
# Calculate y values to plot the arrow between
# Calculate y values to plot the arrow between - always use mu for arrow position
y_pred_treatment = (
results.y_pred_treatment["posterior_predictive"]
.mu.isel({"obs_ind": 1})
Expand Down Expand Up @@ -409,7 +447,7 @@ def _plot_causal_impact_arrow(results, ax):
time_points = self.x_pred_control[self.time_variable_name].values
h_line, h_patch = plot_xY(
time_points,
self.y_pred_control["posterior_predictive"].mu.isel(treated_units=0),
self.y_pred_control["posterior_predictive"][var_name].isel(treated_units=0),
ax=ax,
plot_hdi_kwargs={"color": "C0"},
label="Control group",
Expand All @@ -421,7 +459,9 @@ def _plot_causal_impact_arrow(results, ax):
time_points = self.x_pred_control[self.time_variable_name].values
h_line, h_patch = plot_xY(
time_points,
self.y_pred_treatment["posterior_predictive"].mu.isel(treated_units=0),
self.y_pred_treatment["posterior_predictive"][var_name].isel(
treated_units=0
),
ax=ax,
plot_hdi_kwargs={"color": "C1"},
label="Treatment group",
Expand All @@ -436,7 +476,7 @@ def _plot_causal_impact_arrow(results, ax):
y_pred_cf = az.extract(
self.y_pred_counterfactual,
group="posterior_predictive",
var_names="mu",
var_names=var_name,
)
# Select single unit data for plotting
y_pred_cf_single = y_pred_cf.isel(treated_units=0)
Expand All @@ -459,7 +499,7 @@ def _plot_causal_impact_arrow(results, ax):
else:
h_line, h_patch = plot_xY(
time_points,
self.y_pred_counterfactual.posterior_predictive.mu.isel(
self.y_pred_counterfactual.posterior_predictive[var_name].isel(
treated_units=0
),
ax=ax,
Expand All @@ -482,6 +522,11 @@ def _plot_causal_impact_arrow(results, ax):
labels=labels,
fontsize=LEGEND_FONT_SIZE,
)

# Add HDI type annotation to the title
if show_hdi_annotation:
add_hdi_annotation(ax, response_type)

return fig, ax

def _ols_plot(
Expand Down
Loading
Loading