Skip to content

Commit 8342c7d

Browse files
committed
Address PR #826 review feedback (round 2)
Ben's review from 2026-04-23:

PlaceboInTime (checks/placebo_in_time.py):
- Add allow_overlap parameter (default False) that enforces non-overlap of pseudo-intervention windows, so random folds no longer violate the hierarchical model's exchangeability assumption by default.
- Replace the "retry with a different random_seed" error path with a bounded greedy-selection retry loop (MAX_RANDOM_SELECTION_RETRIES=16) using deterministic sub-seeds derived from random_seed; failure message now names the knobs to relax (allow_overlap, min_gap, n_folds).
- Warn when a pre-drawn numpy expected_effect_prior is shorter than the number of replications requested by the assurance simulation, documenting the cycling behaviour in _draw_expected_effect_samples.
- Surface allow_overlap=True in __repr__ following the same "non-default only" pattern as selection_method.
- Document how intervention_length falls back to data.index.max() - treatment_time when treatment_end_time is unset, shrinking the random-mode eligible window.

OutcomeFalsification (checks/outcome_falsification.py):
- Warn at run() when storing >= 3 fitted experiments, explaining that store_experiments=False keeps only summary statistics.
- Rewrite __repr__ to hide default alpha and store_experiments flags.
- Drop dead np.linalg.LinAlgError from the caught-exception tuple and the now-unused numpy import.

Docs:
- Wire its_place_in_time_analysis.ipynb into the ITS toctree so it stops being an orphan page.
- Cross-link the notebook from sensitivity_checks.md under the "Where examples already exist" list.
- Tag plot-only cells with hide-input and sampler-heavy cells with hide-output so the rendered page collapses non-essential chunks.

Tests:
- Pin allow_overlap default, the non-overlap invariant, the _windows_overlap helper for numeric and datetime indices, the allow_overlap opt-out, the bounded-retry reproducibility and exhaustion paths, and the expected-effect-prior cycling warning.
- Pin the new OutcomeFalsification __repr__ pattern, the store_experiments memory warning at run() time, and its opt-out and below-threshold paths.
- Expand the upstream-bug TODO in test_run_handles_failed_formula.

Made-with: Cursor
1 parent f50caff commit 8342c7d

7 files changed

Lines changed: 561 additions & 52 deletions

File tree

causalpy/checks/outcome_falsification.py

Lines changed: 29 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,6 @@
3535
from dataclasses import dataclass
3636
from typing import Any
3737

38-
import numpy as np
3938
import pandas as pd
4039
from patsy import PatsyError
4140

@@ -50,6 +49,8 @@
5049

5150
logger = logging.getLogger(__name__)
5251

52+
_STORE_EXPERIMENTS_WARN_THRESHOLD = 3
53+
5354

5455
@dataclass
5556
class FalsificationResult:
@@ -101,7 +102,12 @@ class OutcomeFalsification:
101102
``InferenceData``), which lets users inspect posteriors but
102103
can be memory-heavy for many formulas. Set to ``False`` to
103104
keep only the summary statistics (``effect_mean``,
104-
``hdi_lower``, ``hdi_upper``).
105+
``hdi_lower``, ``hdi_upper``). A one-off warning is emitted at
106+
:meth:`run` when ``store_experiments=True`` and at least
107+
``3`` formulas are supplied, because the combined
108+
``InferenceData`` footprint of several fitted experiments can
109+
easily reach hundreds of MB on larger datasets
110+
(e.g. :class:`PiecewiseITS`).
105111
106112
Examples
107113
--------
@@ -246,6 +252,20 @@ def run(
246252
"""
247253
self.validate(experiment)
248254

255+
if (
256+
self.store_experiments
257+
and len(self.formulas) >= _STORE_EXPERIMENTS_WARN_THRESHOLD
258+
):
259+
warnings.warn(
260+
f"OutcomeFalsification will store {len(self.formulas)} fitted "
261+
f"experiments (each with its own InferenceData). The combined "
262+
f"footprint can reach hundreds of MB on large datasets or "
263+
f"models with many posterior samples. Pass "
264+
f"store_experiments=False if you only need the summary "
265+
f"statistics (effect_mean, hdi_lower, hdi_upper).",
266+
stacklevel=2,
267+
)
268+
249269
results: list[FalsificationResult] = []
250270
rows: list[dict[str, Any]] = []
251271
failed_formulas: list[str] = []
@@ -285,7 +305,6 @@ def run(
285305
ValueError,
286306
KeyError,
287307
RuntimeError,
288-
np.linalg.LinAlgError,
289308
) as exc:
290309
logger.warning(
291310
"OutcomeFalsification: failed for formula '%s'",
@@ -336,7 +355,10 @@ def run(
336355
)
337356

338357
def __repr__(self) -> str:
339-
return (
340-
f"OutcomeFalsification(formulas={self.formulas!r}, "
341-
f"alpha={self.alpha}, store_experiments={self.store_experiments})"
342-
)
358+
"""Return a string representation, showing only non-default flags."""
359+
parts = [f"formulas={self.formulas!r}"]
360+
if self.alpha != 0.05:
361+
parts.append(f"alpha={self.alpha}")
362+
if not self.store_experiments:
363+
parts.append("store_experiments=False")
364+
return f"OutcomeFalsification({', '.join(parts)})"

causalpy/checks/placebo_in_time.py

Lines changed: 147 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@
3636
from __future__ import annotations
3737

3838
import logging
39+
import warnings
3940
from dataclasses import dataclass, field
4041
from typing import Any, Literal
4142

@@ -54,6 +55,7 @@
5455
logger = logging.getLogger(__name__)
5556

5657
MIN_FOLD_OBSERVATIONS = 3
58+
MAX_RANDOM_SELECTION_RETRIES = 16
5759

5860
_DEFAULT_SAMPLE_KWARGS: dict[str, Any] = {
5961
"draws": 1000,
@@ -153,15 +155,32 @@ class PlaceboInTime:
153155
min_training_pct : float, default 0.30
154156
*(random mode only)* Minimum fraction of total pre-period
155157
observations that must precede each candidate placebo window.
158+
159+
Note: the eligible pre-period is further shortened because a
160+
candidate's pseudo-intervention window must also end before the
161+
actual treatment. When ``treatment_end_time`` is not set on
162+
the experiment, ``intervention_length`` defaults to
163+
``data.index.max() - treatment_time`` (roughly the post-period
164+
length), which can make the effective eligible window much
165+
smaller than ``(1 - min_training_pct)`` suggests.
156166
min_gap : int, default 1
157167
*(random mode only)* Minimum number of pre-intervention
158168
observations between any two selected folds, measured as
159-
positions in the sorted pre-period index. Prevents
160-
near-duplicate folds. Note: selection is greedy without
161-
backtracking, so very large ``min_gap`` values relative to
162-
the candidate pool can raise ``ValueError`` even when a
163-
valid configuration exists. Rerun with a different
164-
``random_seed`` or reduce ``min_gap`` if this happens.
169+
positions in the sorted pre-period index. The default of ``1``
170+
only forbids picking the same candidate twice; use a larger
171+
value to spread folds further apart. When ``allow_overlap``
172+
is ``False`` (the default) non-overlap of pseudo-intervention
173+
windows is enforced independently of ``min_gap``.
174+
allow_overlap : bool, default False
175+
*(random mode only)* If ``False`` (the default), selected
176+
pseudo-intervention windows are required to be non-overlapping
177+
in index/time units. Two folds at times ``t_a`` and ``t_b``
178+
are considered non-overlapping when
179+
``abs(t_a - t_b) >= intervention_length``. Set to ``True`` to
180+
allow overlapping windows, which relaxes the constraint at the
181+
cost of violating the exchangeability assumption of the
182+
hierarchical status-quo model (each fold mean is treated as an
183+
independent draw from a common ``mu_status_quo``).
165184
exclude_periods : set[str] | None, default None
166185
*(random mode only)* Set of period labels to exclude from
167186
candidate selection. For datetime-indexed data, use
@@ -225,6 +244,7 @@ def __init__(
225244
selection_method: Literal["sequential", "random"] = "sequential",
226245
min_training_pct: float = 0.30,
227246
min_gap: int = 1,
247+
allow_overlap: bool = False,
228248
exclude_periods: set[str] | None = None,
229249
experiment_factory: Any | None = None,
230250
sample_kwargs: dict[str, Any] | None = None,
@@ -258,6 +278,7 @@ def __init__(
258278
self.selection_method = selection_method
259279
self.min_training_pct = min_training_pct
260280
self.min_gap = min_gap
281+
self.allow_overlap = allow_overlap
261282
self.exclude_periods = exclude_periods
262283
self.experiment_factory = experiment_factory
263284
self.sample_kwargs = {**_DEFAULT_SAMPLE_KWARGS, **(sample_kwargs or {})}
@@ -361,8 +382,15 @@ def _compute_random_fold_treatment_times(
361382
362383
Builds a list of eligible candidate dates/indices from the
363384
pre-intervention data, then randomly selects ``n_folds``
364-
candidates subject to ``min_training_pct``, ``min_gap``, and
365-
``exclude_periods`` constraints.
385+
candidates subject to ``min_training_pct``, ``min_gap``,
386+
``allow_overlap``, and ``exclude_periods`` constraints.
387+
388+
Selection is greedy without backtracking and can fail to pick
389+
a valid combination on the first try when constraints are
390+
tight. To preserve reproducibility when a ``random_seed`` is
391+
supplied, the method retries up to
392+
:data:`MAX_RANDOM_SELECTION_RETRIES` times using deterministic
393+
sub-seeds derived from ``random_seed`` before giving up.
366394
367395
Parameters
368396
----------
@@ -381,7 +409,9 @@ def _compute_random_fold_treatment_times(
381409
Raises
382410
------
383411
ValueError
384-
If not enough eligible candidates exist.
412+
If not enough eligible candidates exist, or if no feasible
413+
selection is found after ``MAX_RANDOM_SELECTION_RETRIES``
414+
greedy attempts.
385415
"""
386416
pre_data = data.loc[data.index < treatment_time]
387417
if pre_data.empty:
@@ -397,19 +427,16 @@ def _compute_random_fold_treatment_times(
397427
# between selected folds, not a candidate-list distance.
398428
candidates: list[tuple[int, Any]] = []
399429
for pos, idx_val in enumerate(all_indices):
400-
# Check exclusion
401430
if hasattr(idx_val, "strftime"):
402431
label = idx_val.strftime("%Y-%m")
403432
else:
404433
label = str(idx_val)
405434
if label in exclude:
406435
continue
407436

408-
# Enough training data before this point?
409437
if pos < min_training:
410438
continue
411439

412-
# Pseudo-intervention must end before true treatment
413440
pseudo_end = idx_val + intervention_length
414441
if pseudo_end > treatment_time:
415442
continue
@@ -423,32 +450,95 @@ def _compute_random_fold_treatment_times(
423450
f"lower min_training_pct, or relax exclude_periods."
424451
)
425452

426-
rng = np.random.default_rng(self.random_seed)
453+
last_err: ValueError | None = None
454+
for attempt in range(MAX_RANDOM_SELECTION_RETRIES):
455+
# Deterministic sub-seeds: successive attempts reshuffle
456+
# choices in a reproducible way when ``random_seed`` is set
457+
# and remain non-deterministic (as expected) when it isn't.
458+
sub_seed: int | None
459+
if self.random_seed is None:
460+
sub_seed = None
461+
else:
462+
sub_seed = int(self.random_seed) + attempt
463+
rng = np.random.default_rng(sub_seed)
464+
try:
465+
selected = self._try_greedy_selection(
466+
candidates, intervention_length, rng
467+
)
468+
return sorted(candidates[i][1] for i in selected)
469+
except ValueError as err:
470+
last_err = err
471+
continue
472+
473+
raise ValueError(
474+
f"Cannot select {self.n_folds} folds with min_gap="
475+
f"{self.min_gap} and allow_overlap={self.allow_overlap} "
476+
f"after {MAX_RANDOM_SELECTION_RETRIES} greedy attempts with "
477+
f"deterministic sub-seeds. Relax constraints "
478+
f"(smaller min_gap, set allow_overlap=True, or reduce "
479+
f"n_folds). Last underlying error: {last_err}"
480+
)
481+
482+
def _try_greedy_selection(
483+
self,
484+
candidates: list[tuple[int, Any]],
485+
intervention_length: Any,
486+
rng: np.random.Generator,
487+
) -> list[int]:
488+
"""Single greedy pass over candidates; raises on infeasibility.
489+
490+
Enforces two constraints between any pair of selected folds:
491+
492+
* ``min_gap`` positional distance in the candidate index.
493+
* When ``allow_overlap`` is ``False``, non-overlap of the
494+
pseudo-intervention windows, expressed in the same units as
495+
``intervention_length``. Two windows ``[t_a, t_a + L)`` and
496+
``[t_b, t_b + L)`` are non-overlapping iff
497+
``abs(t_a - t_b) >= L``.
498+
"""
427499
pool = list(range(len(candidates)))
428500
selected: list[int] = []
429501

430502
for _ in range(self.n_folds):
431-
valid = [
432-
i
433-
for i in pool
434-
if all(
435-
abs(candidates[i][0] - candidates[s][0]) >= self.min_gap
436-
for s in selected
437-
)
438-
]
503+
valid: list[int] = []
504+
for i in pool:
505+
pos_i, idx_val_i = candidates[i]
506+
ok = True
507+
for s in selected:
508+
pos_s, idx_val_s = candidates[s]
509+
if abs(pos_i - pos_s) < self.min_gap:
510+
ok = False
511+
break
512+
if not self.allow_overlap and self._windows_overlap(
513+
idx_val_i, idx_val_s, intervention_length
514+
):
515+
ok = False
516+
break
517+
if ok:
518+
valid.append(i)
439519
if not valid:
440520
raise ValueError(
441-
f"Cannot select {self.n_folds} folds with min_gap="
442-
f"{self.min_gap}. Greedy selection without backtracking "
443-
f"can fail even when a valid configuration exists; "
444-
f"retry with a different random_seed, or reduce "
445-
f"min_gap or n_folds."
521+
"No candidate remaining satisfies min_gap and "
522+
"non-overlap constraints; greedy selection stuck."
446523
)
447524
pick = int(rng.choice(valid))
448525
selected.append(pick)
449526
pool.remove(pick)
527+
return selected
450528

451-
return sorted(candidates[i][1] for i in selected)
529+
@staticmethod
530+
def _windows_overlap(idx_a: Any, idx_b: Any, intervention_length: Any) -> bool:
531+
"""Return True if two intervention windows overlap.
532+
533+
A window starting at ``idx`` spans ``[idx, idx + intervention_length)``.
534+
Two windows overlap iff their absolute separation is strictly
535+
less than ``intervention_length``. The comparison uses
536+
``idx + intervention_length`` rather than computing a Timedelta
537+
directly so that it works uniformly for numeric indices and
538+
for datetime indices combined with ``pd.DateOffset``.
539+
"""
540+
earlier, later = (idx_a, idx_b) if idx_a <= idx_b else (idx_b, idx_a)
541+
return later < earlier + intervention_length
452542

453543
def _get_fold_data(
454544
self,
@@ -609,11 +699,38 @@ def bayesian_rope_decision(
609699
# ------------------------------------------------------------------
610700

611701
def _draw_expected_effect_samples(self, n: int) -> np.ndarray:
612-
"""Draw samples from the expected-effect prior."""
702+
"""Draw samples from the expected-effect prior.
703+
704+
Parameters
705+
----------
706+
n : int
707+
Desired number of samples. Objects with an ``.rvs(n)``
708+
method receive ``n`` directly. Pre-drawn numpy arrays are
709+
returned as-is, and :meth:`_compute_assurance` cycles
710+
through them via ``i % len(prior)`` when the array is
711+
shorter than the number of replications. A warning is
712+
emitted in this case because short arrays can introduce
713+
spurious structure in the simulated decisions.
714+
715+
Returns
716+
-------
717+
np.ndarray
718+
Samples from the expected-effect prior.
719+
"""
613720
prior = self.expected_effect_prior
614721
if prior is None:
615722
raise ValueError("expected_effect_prior is not set.")
616723
if isinstance(prior, np.ndarray):
724+
if len(prior) < n:
725+
warnings.warn(
726+
f"expected_effect_prior has {len(prior)} samples, fewer "
727+
f"than the {n} replications requested by the assurance "
728+
f"simulation; the array will be cycled through via "
729+
f"index % len(prior). Pass a longer array or an object "
730+
f"with an .rvs(n) method (e.g. a PreliZ/scipy "
731+
f"distribution) to avoid cycling.",
732+
stacklevel=2,
733+
)
617734
return prior
618735
if hasattr(prior, "rvs"):
619736
return np.asarray(prior.rvs(n)) # type: ignore[union-attr]
@@ -905,6 +1022,8 @@ def __repr__(self) -> str:
9051022
parts = [f"n_folds={self.n_folds}"]
9061023
if self.selection_method != "sequential":
9071024
parts.append(f"selection_method={self.selection_method!r}")
1025+
if self.allow_overlap:
1026+
parts.append("allow_overlap=True")
9081027
if self.expected_effect_prior is not None:
9091028
parts.append("assurance=True")
9101029
return f"PlaceboInTime({', '.join(parts)})"

0 commit comments

Comments
 (0)