Skip to content

Commit 4d2442b

Browse files
Warn when staggered DiD ATTs lack untreated-period support (#949)
* Warn when staggered DiD ATTs lack untreated-period support

  Detect calendar periods with no untreated observations, emit a clear UserWarning, and mark non-identified ATT(g, t) and ATT(e) cells with identified=False and NaN estimates instead of silently returning biased values when never-treated units are absent.

  Closes #938

  Co-authored-by: Cursor <cursoragent@cursor.com>

* Mask non-identified cells in get_plot_data_bayesian recompute path

  When hdi_prob differs from the stored aggregation value, pass the recomputed event-time DataFrame through _mark_non_identified_att_rows() so non-identified ATTs stay NaN. Add regression test for the non-default HDI path on a no-never-treated panel.

  Co-authored-by: Cursor <cursoragent@cursor.com>

* Add tests for non-identified ATT coverage gaps

  Cover effect_summary filtering and _mark_non_identified_att_rows edge cases so Codecov patch coverage passes on the identification changes.

  Related: #950 (local patch coverage in prek)

  Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent f85604b commit 4d2442b

4 files changed

Lines changed: 304 additions & 7 deletions

File tree

causalpy/experiments/staggered_did.py

Lines changed: 112 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
different units receive treatment at different times.
2020
"""
2121

22+
import warnings
2223
from typing import Any, Literal
2324

2425
import numpy as np
@@ -58,6 +59,13 @@ class StaggeredDifferenceInDifferences(BaseExperiment):
5859
units would have followed parallel outcome trajectories.
5960
3. **No anticipation**: Units do not change their behavior in anticipation
6061
of future treatment.
62+
4. **Untreated support at each calendar period**: The time fixed effect
63+
:math:`\\gamma_t` for calendar period :math:`t` is identified only if at
64+
least one unit is untreated in that period. Without never-treated units,
65+
post-treatment effects for the last-treated cohort (and any calendar
66+
periods where every unit is already treated) are not identified. CausalPy
67+
warns when this condition fails and marks the affected ``ATT(g, t)`` and
68+
``ATT(e)`` cells as non-identified in the output tables.
6169
6270
Parameters
6371
----------
@@ -94,8 +102,14 @@ class StaggeredDifferenceInDifferences(BaseExperiment):
94102
and tau_hat (treatment effect) columns.
95103
att_group_time_ : pd.DataFrame
96104
Group-time ATT estimates: ATT(g, t) for each cohort g and calendar time t.
105+
Includes an ``identified`` column; non-identified cells have ``NaN`` estimates.
97106
att_event_time_ : pd.DataFrame
98107
Event-time ATT estimates: ATT(e) for each event-time e = t - G.
108+
Includes an ``identified`` column; non-identified cells have ``NaN`` estimates.
109+
non_identified_periods_ : set
110+
Calendar periods with no untreated observations.
111+
non_identified_cohorts_ : set
112+
Treatment cohorts with at least one non-identified post-treatment ATT(g, t).
99113
100114
Notes
101115
-----
@@ -185,6 +199,9 @@ def __init__(
185199
# Step 3: Identify untreated observations (training set)
186200
self._identify_untreated_observations()
187201

202+
# Step 3b: Check calendar-period identification support
203+
self._check_att_identification()
204+
188205
# Step 4: Build design matrices
189206
self._build_design_matrices()
190207

@@ -319,6 +336,92 @@ def _identify_untreated_observations(self) -> None:
319336
"Ensure there are never-treated units or pre-treatment periods."
320337
)
321338

339+
def _get_periods_without_untreated_support(self) -> set[Any]:
340+
"""Return calendar periods with zero untreated observations."""
341+
untreated_periods = set(
342+
self.data.loc[self.data["_is_untreated"], self.time_variable_name].unique()
343+
)
344+
all_periods = set(self.data[self.time_variable_name].unique())
345+
return all_periods - untreated_periods
346+
347+
def _get_non_identified_cohorts(self, periods: set[Any]) -> set[Any]:
348+
"""Return cohorts with post-treatment cells in non-identified periods."""
349+
non_identified_cohorts: set[Any] = set()
350+
for cohort in self.cohorts:
351+
for period in periods:
352+
if period >= cohort:
353+
non_identified_cohorts.add(cohort)
354+
break
355+
return non_identified_cohorts
356+
357+
def _check_att_identification(self) -> None:
358+
"""Detect non-identified ATT cells and warn when untreated support is missing."""
359+
self.non_identified_periods_ = self._get_periods_without_untreated_support()
360+
self.non_identified_cohorts_ = self._get_non_identified_cohorts(
361+
self.non_identified_periods_
362+
)
363+
364+
if not self.non_identified_periods_:
365+
return
366+
367+
periods_str = ", ".join(str(p) for p in sorted(self.non_identified_periods_))
368+
cohorts_str = ", ".join(str(c) for c in sorted(self.non_identified_cohorts_))
369+
warnings.warn(
370+
"No untreated observations in calendar period(s) "
371+
f"{{{periods_str}}}; treatment effects for cohort(s) "
372+
f"{{{cohorts_str}}} are not identified at the affected post-treatment "
373+
"cells. Provide never-treated units or restrict the event window. "
374+
"Non-identified ATT(g, t) and ATT(e) cells are marked in the output "
375+
"tables (identified=False) with NaN estimates.",
376+
UserWarning,
377+
stacklevel=2,
378+
)
379+
380+
def _is_calendar_period_identified(self, period: Any) -> bool:
381+
"""Return whether calendar period ``period`` has untreated support."""
382+
return period not in self.non_identified_periods_
383+
384+
def _is_event_time_att_identified(self, event_time: int) -> bool:
385+
"""Return whether aggregated ATT(e) is identified."""
386+
for cohort in self.cohorts:
387+
period = cohort + event_time
388+
has_contributing_obs = (
389+
(self.data["G"] == cohort)
390+
& (self.data[self.time_variable_name] == period)
391+
& (self.data["event_time"] == event_time)
392+
).any()
393+
if has_contributing_obs and not self._is_calendar_period_identified(period):
394+
return False
395+
return True
396+
397+
def _mark_non_identified_att_rows(self, att_df: pd.DataFrame) -> pd.DataFrame:
398+
"""Add ``identified`` column and mask non-identified point estimates."""
399+
if len(att_df) == 0:
400+
att_df = att_df.copy()
401+
att_df["identified"] = pd.Series(dtype=bool)
402+
return att_df
403+
404+
att_df = att_df.copy()
405+
if "cohort" in att_df.columns and "time" in att_df.columns:
406+
att_df["identified"] = att_df["time"].map(
407+
self._is_calendar_period_identified
408+
)
409+
elif "event_time" in att_df.columns:
410+
att_df["identified"] = att_df["event_time"].apply(
411+
lambda e: self._is_event_time_att_identified(int(e))
412+
)
413+
else:
414+
att_df["identified"] = True
415+
416+
value_columns = [
417+
col
418+
for col in ("att", "att_lower", "att_upper", "att_std")
419+
if col in att_df.columns
420+
]
421+
for col in value_columns:
422+
att_df.loc[~att_df["identified"], col] = np.nan
423+
return att_df
424+
322425
def _build_design_matrices(self) -> None:
323426
"""Build design matrices using patsy."""
324427
# Build design matrix for the full data
@@ -499,7 +602,9 @@ def _aggregate_effects_bayesian(
499602
"att_upper": float(np.percentile(tau_gt, upper_pct)),
500603
}
501604
)
502-
self.att_group_time_ = pd.DataFrame(att_gt_rows)
605+
self.att_group_time_ = self._mark_non_identified_att_rows(
606+
pd.DataFrame(att_gt_rows)
607+
)
503608

504609
# --- Event-time ATTs (including pre-treatment placebo) ---
505610
att_et_rows: list[dict] = []
@@ -563,7 +668,9 @@ def _aggregate_effects_bayesian(
563668
}
564669
)
565670

566-
self.att_event_time_ = pd.DataFrame(att_et_rows)
671+
self.att_event_time_ = self._mark_non_identified_att_rows(
672+
pd.DataFrame(att_et_rows)
673+
)
567674

568675
def _aggregate_effects_ols(
569676
self, treated_data: pd.DataFrame, pretreatment_data: pd.DataFrame
@@ -585,7 +692,7 @@ def _aggregate_effects_ols(
585692
.reset_index()
586693
)
587694
att_gt.columns = ["cohort", "time", "att", "att_std", "n_obs"]
588-
self.att_group_time_ = att_gt
695+
self.att_group_time_ = self._mark_non_identified_att_rows(att_gt)
589696

590697
# --- Event-time ATTs (including pre-treatment placebo) ---
591698
# Compute tau_hat for pre-treatment observations (residuals)
@@ -613,7 +720,7 @@ def _aggregate_effects_ols(
613720
)
614721
att_et.columns = ["event_time", "att", "att_std", "n_obs"]
615722
att_et["event_time"] = att_et["event_time"].astype(int)
616-
self.att_event_time_ = att_et
723+
self.att_event_time_ = self._mark_non_identified_att_rows(att_et)
617724

618725
def summary(
619726
self, round_to: int | None = 2, include_group_time: bool = False
@@ -1506,7 +1613,7 @@ def get_plot_data_bayesian(self, hdi_prob: float = HDI_PROB) -> pd.DataFrame:
15061613
}
15071614
)
15081615

1509-
return pd.DataFrame(att_et_rows)
1616+
return self._mark_non_identified_att_rows(pd.DataFrame(att_et_rows))
15101617

15111618
def get_plot_data_ols(self) -> pd.DataFrame:
15121619
"""Get plotting data for OLS model.

causalpy/reporting.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,10 @@ def _effect_summary_staggered_did(
382382
# Separate pre-treatment (placebo) and post-treatment effects
383383
pre_treatment = att_et[att_et["event_time"] < 0]
384384
post_treatment = att_et[att_et["event_time"] >= 0]
385+
if "identified" in post_treatment.columns:
386+
post_treatment = post_treatment[post_treatment["identified"]]
387+
if "identified" in pre_treatment.columns:
388+
pre_treatment = pre_treatment[pre_treatment["identified"]]
385389

386390
# Build summary table with all event-time effects
387391
table = att_et.copy()

causalpy/tests/test_staggered_did.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
Tests for StaggeredDifferenceInDifferences experiment class.
1616
"""
1717

18+
import warnings
19+
1820
import numpy as np
1921
import pandas as pd
2022
import pytest
@@ -703,6 +705,183 @@ def test_staggered_did_group_time_att_structure():
703705
assert len(result.att_group_time_["cohort"].unique()) >= 2
704706

705707

708+
def _no_never_treated_staggered_did_df() -> pd.DataFrame:
    """Build the shared panel in which the two cohorts account for all 20 units."""
    panel_kwargs = {
        "n_units": 20,
        "n_time_periods": 10,
        # Cohort sizes sum to n_units, so no unit is left never-treated.
        "treatment_cohorts": {3: 10, 7: 10},
        "seed": 42,
    }
    return generate_staggered_did_data(**panel_kwargs)
716+
717+
718+
@pytest.mark.parametrize(
    "model_factory",
    [
        pytest.param(lambda: LinearRegression(), id="ols"),
        pytest.param(
            lambda: cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
            id="pymc",
        ),
    ],
)
def test_staggered_did_warns_non_identified_without_never_treated(
    model_factory, mock_pymc_sample
):
    """Fitting a panel without never-treated units warns and masks affected ATTs."""
    panel = _no_never_treated_staggered_did_df()
    expected_warning = "No untreated observations in calendar period"

    with pytest.warns(UserWarning, match=expected_warning):
        fitted = cp.StaggeredDifferenceInDifferences(
            panel,
            formula="y ~ 1 + C(unit) + C(time)",
            unit_variable_name="unit",
            time_variable_name="time",
            treated_variable_name="treated",
            treatment_time_variable_name="treatment_time",
            model=model_factory(),
        )

    # Bookkeeping attributes record the unsupported period and cohort.
    assert fitted.non_identified_periods_
    assert 7 in fitted.non_identified_periods_
    assert 7 in fitted.non_identified_cohorts_

    # Group-time table: flagged rows exist and carry NaN point estimates.
    group_time = fitted.att_group_time_
    assert "identified" in group_time.columns
    masked_gt = group_time[~group_time["identified"]]
    assert len(masked_gt) > 0
    assert masked_gt["att"].isna().all()

    # Event-time table: same check, restricted to post-treatment event times.
    event_time = fitted.att_event_time_
    assert "identified" in event_time.columns
    is_post = event_time["event_time"] >= 0
    masked_et = event_time[is_post & ~event_time["identified"]]
    assert len(masked_et) > 0
    assert masked_et["att"].isna().all()
763+
764+
765+
def test_staggered_did_get_plot_data_bayesian_masks_non_identified_on_recompute(
    mock_pymc_sample,
):
    """Requesting ``hdi_prob=0.80`` must still return masked non-identified cells."""
    panel = _no_never_treated_staggered_did_df()
    expected_warning = "No untreated observations in calendar period"

    with pytest.warns(UserWarning, match=expected_warning):
        fitted = cp.StaggeredDifferenceInDifferences(
            panel,
            formula="y ~ 1 + C(unit) + C(time)",
            unit_variable_name="unit",
            time_variable_name="time",
            treated_variable_name="treated",
            treatment_time_variable_name="treatment_time",
            model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
        )

    plot_data = fitted.get_plot_data_bayesian(hdi_prob=0.80)

    assert "identified" in plot_data.columns
    is_post = plot_data["event_time"] >= 0
    masked_post = plot_data[is_post & ~plot_data["identified"]]
    assert len(masked_post) > 0
    # Point estimate and both interval bounds must all be masked.
    assert masked_post["att"].isna().all()
    assert masked_post["att_lower"].isna().all()
    assert masked_post["att_upper"].isna().all()
794+
795+
796+
def test_staggered_did_effect_summary_excludes_non_identified_cells():
    """``effect_summary`` still produces text and a table when some cells are masked."""
    panel = _no_never_treated_staggered_did_df()
    expected_warning = "No untreated observations in calendar period"

    with pytest.warns(UserWarning, match=expected_warning):
        fitted = cp.StaggeredDifferenceInDifferences(
            panel,
            formula="y ~ 1 + C(unit) + C(time)",
            unit_variable_name="unit",
            time_variable_name="time",
            treated_variable_name="treated",
            treatment_time_variable_name="treatment_time",
            model=LinearRegression(),
        )

    # Precondition: at least one post-treatment event time is non-identified.
    event_time = fitted.att_event_time_
    post_rows = event_time[event_time["event_time"] >= 0]
    assert (~post_rows["identified"]).any()

    summary = fitted.effect_summary()
    assert "Staggered DiD" in summary.text
    assert isinstance(summary.table, pd.DataFrame)
819+
820+
821+
def test_staggered_did_mark_non_identified_att_rows_edge_cases():
    """Exercise the empty-frame and unrecognised-column branches of the marker."""
    panel = _no_never_treated_staggered_did_df()
    expected_warning = "No untreated observations in calendar period"

    with pytest.warns(UserWarning, match=expected_warning):
        fitted = cp.StaggeredDifferenceInDifferences(
            panel,
            formula="y ~ 1 + C(unit) + C(time)",
            unit_variable_name="unit",
            time_variable_name="time",
            treated_variable_name="treated",
            treatment_time_variable_name="treatment_time",
            model=LinearRegression(),
        )

    mark = fitted._mark_non_identified_att_rows

    # Empty input: only the flag column is added and no rows appear.
    from_empty = mark(pd.DataFrame())
    assert list(from_empty.columns) == ["identified"]
    assert len(from_empty) == 0

    # No recognised key columns: every row is flagged as identified.
    from_unknown = mark(pd.DataFrame({"x": [1]}))
    assert from_unknown["identified"].all()
844+
845+
846+
@pytest.mark.parametrize(
    "model_factory",
    [
        pytest.param(lambda: LinearRegression(), id="ols"),
        pytest.param(
            lambda: cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
            id="pymc",
        ),
    ],
)
def test_staggered_did_all_identified_with_never_treated(
    model_factory, mock_pymc_sample
):
    """With 30 units and 20 assigned to cohorts, no cell should be non-identified."""
    panel = generate_staggered_did_data(
        n_units=30,
        n_time_periods=10,
        treatment_cohorts={3: 10, 7: 10},
        seed=42,
    )

    # Escalate UserWarning to an error so an unexpected warning fails the test.
    with warnings.catch_warnings():
        warnings.simplefilter("error", UserWarning)
        fitted = cp.StaggeredDifferenceInDifferences(
            panel,
            formula="y ~ 1 + C(unit) + C(time)",
            unit_variable_name="unit",
            time_variable_name="time",
            treated_variable_name="treated",
            treatment_time_variable_name="treatment_time",
            model=model_factory(),
        )

    assert fitted.non_identified_periods_ == set()
    assert fitted.att_group_time_["identified"].all()

    event_time = fitted.att_event_time_
    post_rows = event_time[event_time["event_time"] >= 0]
    assert post_rows["identified"].all()
883+
884+
706885
def test_staggered_did_no_untreated_observations():
707886
"""Test that having no untreated observations raises DataException."""
708887
# All units treated from time 0

0 commit comments

Comments
 (0)