3636from __future__ import annotations
3737
3838import logging
39+ import warnings
3940from dataclasses import dataclass , field
4041from typing import Any , Literal
4142
5455logger = logging .getLogger (__name__ )
5556
5657MIN_FOLD_OBSERVATIONS = 3
58+ MAX_RANDOM_SELECTION_RETRIES = 16
5759
5860_DEFAULT_SAMPLE_KWARGS : dict [str , Any ] = {
5961 "draws" : 1000 ,
@@ -153,15 +155,32 @@ class PlaceboInTime:
153155 min_training_pct : float, default 0.30
154156 *(random mode only)* Minimum fraction of total pre-period
155157 observations that must precede each candidate placebo window.
158+
159+ Note: the eligible pre-period is further shortened because a
160+ candidate's pseudo-intervention window must also end before the
161+ actual treatment. When ``treatment_end_time`` is not set on
162+ the experiment, ``intervention_length`` defaults to
163+ ``data.index.max() - treatment_time`` (roughly the post-period
164+ length), which can make the effective eligible window much
165+ smaller than ``(1 - min_training_pct)`` suggests.
156166 min_gap : int, default 1
157167 *(random mode only)* Minimum number of pre-intervention
158168 observations between any two selected folds, measured as
159- positions in the sorted pre-period index. Prevents
160- near-duplicate folds. Note: selection is greedy without
161- backtracking, so very large ``min_gap`` values relative to
162- the candidate pool can raise ``ValueError`` even when a
163- valid configuration exists. Rerun with a different
164- ``random_seed`` or reduce ``min_gap`` if this happens.
169+ positions in the sorted pre-period index. The default of ``1``
170+ only forbids picking the same candidate twice; use a larger
171+ value to spread folds further apart. When ``allow_overlap``
172+ is ``False`` (the default) non-overlap of pseudo-intervention
173+ windows is enforced independently of ``min_gap``.
174+ allow_overlap : bool, default False
175+ *(random mode only)* If ``False`` (the default), selected
176+ pseudo-intervention windows are required to be non-overlapping
177+ in index/time units. Two folds at times ``t_a`` and ``t_b``
178+ are considered non-overlapping when
179+ ``abs(t_a - t_b) >= intervention_length``. Set to ``True`` to
180+ allow overlapping windows, which relaxes the constraint at the
181+ cost of violating the exchangeability assumption of the
182+ hierarchical status-quo model (each fold mean is treated as an
183+ independent draw from a common ``mu_status_quo``).
165184 exclude_periods : set[str] | None, default None
166185 *(random mode only)* Set of period labels to exclude from
167186 candidate selection. For datetime-indexed data, use
@@ -225,6 +244,7 @@ def __init__(
225244 selection_method : Literal ["sequential" , "random" ] = "sequential" ,
226245 min_training_pct : float = 0.30 ,
227246 min_gap : int = 1 ,
247+ allow_overlap : bool = False ,
228248 exclude_periods : set [str ] | None = None ,
229249 experiment_factory : Any | None = None ,
230250 sample_kwargs : dict [str , Any ] | None = None ,
@@ -258,6 +278,7 @@ def __init__(
258278 self .selection_method = selection_method
259279 self .min_training_pct = min_training_pct
260280 self .min_gap = min_gap
281+ self .allow_overlap = allow_overlap
261282 self .exclude_periods = exclude_periods
262283 self .experiment_factory = experiment_factory
263284 self .sample_kwargs = {** _DEFAULT_SAMPLE_KWARGS , ** (sample_kwargs or {})}
@@ -361,8 +382,15 @@ def _compute_random_fold_treatment_times(
361382
362383 Builds a list of eligible candidate dates/indices from the
363384 pre-intervention data, then randomly selects ``n_folds``
364- candidates subject to ``min_training_pct``, ``min_gap``, and
365- ``exclude_periods`` constraints.
385+ candidates subject to ``min_training_pct``, ``min_gap``,
386+ ``allow_overlap``, and ``exclude_periods`` constraints.
387+
388+ Selection is greedy without backtracking and can fail to pick
389+ a valid combination on the first try when constraints are
390+ tight. To preserve reproducibility when a ``random_seed`` is
391+ supplied, the method retries up to
392+ :data:`MAX_RANDOM_SELECTION_RETRIES` times using deterministic
393+ sub-seeds derived from ``random_seed`` before giving up.
366394
367395 Parameters
368396 ----------
@@ -381,7 +409,9 @@ def _compute_random_fold_treatment_times(
381409 Raises
382410 ------
383411 ValueError
384- If not enough eligible candidates exist.
412+ If not enough eligible candidates exist, or if no feasible
413+ selection is found after ``MAX_RANDOM_SELECTION_RETRIES``
414+ greedy attempts.
385415 """
386416 pre_data = data .loc [data .index < treatment_time ]
387417 if pre_data .empty :
@@ -397,19 +427,16 @@ def _compute_random_fold_treatment_times(
397427 # between selected folds, not a candidate-list distance.
398428 candidates : list [tuple [int , Any ]] = []
399429 for pos , idx_val in enumerate (all_indices ):
400- # Check exclusion
401430 if hasattr (idx_val , "strftime" ):
402431 label = idx_val .strftime ("%Y-%m" )
403432 else :
404433 label = str (idx_val )
405434 if label in exclude :
406435 continue
407436
408- # Enough training data before this point?
409437 if pos < min_training :
410438 continue
411439
412- # Pseudo-intervention must end before true treatment
413440 pseudo_end = idx_val + intervention_length
414441 if pseudo_end > treatment_time :
415442 continue
@@ -423,32 +450,95 @@ def _compute_random_fold_treatment_times(
423450 f"lower min_training_pct, or relax exclude_periods."
424451 )
425452
426- rng = np .random .default_rng (self .random_seed )
453+ last_err : ValueError | None = None
454+ for attempt in range (MAX_RANDOM_SELECTION_RETRIES ):
455+ # Deterministic sub-seeds: successive attempts reshuffle
456+ # choices in a reproducible way when ``random_seed`` is set
457+ # and remain non-deterministic (as expected) when it isn't.
458+ sub_seed : int | None
459+ if self .random_seed is None :
460+ sub_seed = None
461+ else :
462+ sub_seed = int (self .random_seed ) + attempt
463+ rng = np .random .default_rng (sub_seed )
464+ try :
465+ selected = self ._try_greedy_selection (
466+ candidates , intervention_length , rng
467+ )
468+ return sorted (candidates [i ][1 ] for i in selected )
469+ except ValueError as err :
470+ last_err = err
471+ continue
472+
473+ raise ValueError (
474+ f"Cannot select { self .n_folds } folds with min_gap="
475+ f"{ self .min_gap } and allow_overlap={ self .allow_overlap } "
476+ f"after { MAX_RANDOM_SELECTION_RETRIES } greedy attempts with "
477+ f"deterministic sub-seeds. Relax constraints "
478+ f"(smaller min_gap, set allow_overlap=True, or reduce "
479+ f"n_folds). Last underlying error: { last_err } "
480+ )
481+
482+ def _try_greedy_selection (
483+ self ,
484+ candidates : list [tuple [int , Any ]],
485+ intervention_length : Any ,
486+ rng : np .random .Generator ,
487+ ) -> list [int ]:
488+ """Single greedy pass over candidates; raises on infeasibility.
489+
490+ Enforces two constraints between any pair of selected folds:
491+
492+ * ``min_gap`` positional distance in the candidate index.
493+ * When ``allow_overlap`` is ``False``, non-overlap of the
494+ pseudo-intervention windows, expressed in the same units as
495+ ``intervention_length``. Two windows ``[t_a, t_a + L)`` and
496+ ``[t_b, t_b + L)`` are non-overlapping iff
497+ ``abs(t_a - t_b) >= L``.
498+ """
427499 pool = list (range (len (candidates )))
428500 selected : list [int ] = []
429501
430502 for _ in range (self .n_folds ):
431- valid = [
432- i
433- for i in pool
434- if all (
435- abs (candidates [i ][0 ] - candidates [s ][0 ]) >= self .min_gap
436- for s in selected
437- )
438- ]
503+ valid : list [int ] = []
504+ for i in pool :
505+ pos_i , idx_val_i = candidates [i ]
506+ ok = True
507+ for s in selected :
508+ pos_s , idx_val_s = candidates [s ]
509+ if abs (pos_i - pos_s ) < self .min_gap :
510+ ok = False
511+ break
512+ if not self .allow_overlap and self ._windows_overlap (
513+ idx_val_i , idx_val_s , intervention_length
514+ ):
515+ ok = False
516+ break
517+ if ok :
518+ valid .append (i )
439519 if not valid :
440520 raise ValueError (
441- f"Cannot select { self .n_folds } folds with min_gap="
442- f"{ self .min_gap } . Greedy selection without backtracking "
443- f"can fail even when a valid configuration exists; "
444- f"retry with a different random_seed, or reduce "
445- f"min_gap or n_folds."
521+ "No candidate remaining satisfies min_gap and "
522+ "non-overlap constraints; greedy selection stuck."
446523 )
447524 pick = int (rng .choice (valid ))
448525 selected .append (pick )
449526 pool .remove (pick )
527+ return selected
450528
451- return sorted (candidates [i ][1 ] for i in selected )
529+ @staticmethod
530+ def _windows_overlap (idx_a : Any , idx_b : Any , intervention_length : Any ) -> bool :
531+ """Return True if two intervention windows overlap.
532+
533+ A window starting at ``idx`` spans ``[idx, idx + intervention_length)``.
534+ Two windows overlap iff their absolute separation is strictly
535+ less than ``intervention_length``. The comparison uses
536+ ``idx + intervention_length`` rather than computing a Timedelta
537+ directly so that it works uniformly for numeric indices and
538+ for datetime indices combined with ``pd.DateOffset``.
539+ """
540+ earlier , later = (idx_a , idx_b ) if idx_a <= idx_b else (idx_b , idx_a )
541+ return later < earlier + intervention_length
452542
453543 def _get_fold_data (
454544 self ,
@@ -609,11 +699,38 @@ def bayesian_rope_decision(
609699 # ------------------------------------------------------------------
610700
611701 def _draw_expected_effect_samples (self , n : int ) -> np .ndarray :
612- """Draw samples from the expected-effect prior."""
702+ """Draw samples from the expected-effect prior.
703+
704+ Parameters
705+ ----------
706+ n : int
707+ Desired number of samples. Objects with an ``.rvs(n)``
708+ method receive ``n`` directly. Pre-drawn numpy arrays are
709+ returned as-is, and :meth:`_compute_assurance` cycles
710+ through them via ``i % len(prior)`` when the array is
711+ shorter than the number of replications. A warning is
712+ emitted in this case because short arrays can introduce
713+ spurious structure in the simulated decisions.
714+
715+ Returns
716+ -------
717+ np.ndarray
718+ Samples from the expected-effect prior.
719+ """
613720 prior = self .expected_effect_prior
614721 if prior is None :
615722 raise ValueError ("expected_effect_prior is not set." )
616723 if isinstance (prior , np .ndarray ):
724+ if len (prior ) < n :
725+ warnings .warn (
726+ f"expected_effect_prior has { len (prior )} samples, fewer "
727+ f"than the { n } replications requested by the assurance "
728+ f"simulation; the array will be cycled through via "
729+ f"index % len(prior). Pass a longer array or an object "
730+ f"with an .rvs(n) method (e.g. a PreliZ/scipy "
731+ f"distribution) to avoid cycling." ,
732+ stacklevel = 2 ,
733+ )
617734 return prior
618735 if hasattr (prior , "rvs" ):
619736 return np .asarray (prior .rvs (n )) # type: ignore[union-attr]
@@ -905,6 +1022,8 @@ def __repr__(self) -> str:
9051022 parts = [f"n_folds={ self .n_folds } " ]
9061023 if self .selection_method != "sequential" :
9071024 parts .append (f"selection_method={ self .selection_method !r} " )
1025+ if self .allow_overlap :
1026+ parts .append ("allow_overlap=True" )
9081027 if self .expected_effect_prior is not None :
9091028 parts .append ("assurance=True" )
9101029 return f"PlaceboInTime({ ', ' .join (parts )} )"
0 commit comments