|
| 1 | +# ARCHITECTURE |
| 2 | + |
| 3 | +Internal architecture reference for AI agents and contributors working on CausalPy. |
| 4 | + |
| 5 | +## Module Map |
| 6 | + |
| 7 | +| Path | Purpose | |
| 8 | +|------|---------| |
| 9 | +| `causalpy/__init__.py` | Public API surface — re-exports all experiment classes, models, pipeline, steps, transforms | |
| 10 | +| `causalpy/experiments/` | Package of experiment classes (one per file) plus `base.py` | |
| 11 | +| `causalpy/experiments/base.py` | `BaseExperiment` ABC — dispatch logic, `_render_plot`, maketables hooks | |
| 12 | +| `causalpy/pymc_models.py` | All `PyMCModel` subclasses (Bayesian backend) | |
| 13 | +| `causalpy/skl_models.py` | `ScikitLearnAdaptor` mixin, `WeightedProportion`, `create_causalpy_compatible_class()` | |
| 14 | +| `causalpy/reporting.py` | `EffectSummary` dataclass, statistics computation, prose generation for both backends | |
| 15 | +| `causalpy/maketables_adapters.py` | Backend-specific adapters for optional `maketables` table export | |
| 16 | +| `causalpy/pipeline.py` | `Pipeline`, `PipelineContext`, `PipelineResult`, `Step` protocol | |
| 17 | +| `causalpy/steps/` | Pipeline steps: `EstimateEffect`, `SensitivityAnalysis`, `GenerateReport` | |
| 18 | +| `causalpy/checks/` | Diagnostic checks: `PlaceboInTime`, `PlaceboInSpace`, `LeaveOneOut`, `ConvexHullCheck`, `BandwidthSensitivity`, `McCraryDensityTest`, `PriorSensitivity`, `PersistenceCheck`, `PreTreatmentPlaceboCheck` | |
| 19 | +| `causalpy/transforms.py` | Patsy stateful transforms `step()` and `ramp()` for piecewise ITS | |
| 20 | +| `causalpy/variable_selection_priors.py` | Spike-and-slab and horseshoe priors for IV variable selection | |
| 21 | +| `causalpy/constants.py` | `HDI_PROB` (0.94), `LEGEND_FONT_SIZE` (12) | |
| 22 | +| `causalpy/custom_exceptions.py` | `BadIndexException`, `FormulaException`, `DataException` | |
| 23 | +| `causalpy/utils.py` | Shared helpers: `round_num`, `_as_scalar`, `extract_lift_for_mmm`, `plot_correlations`, formula parsing utils | |
| 24 | +| `causalpy/plot_utils.py` | `plot_xY` (HDI ribbon helper), `get_hdi_to_df` | |
| 25 | +| `causalpy/date_utils.py` | Date axis formatting for matplotlib (`format_date_axes`, `_combine_datetime_indices`) | |
| 26 | +| `causalpy/data/` | `load_data()` and `simulate_data` module for example/synthetic datasets | |
| 27 | +| `causalpy/version.py` | `__version__` string | |
| 28 | +| `causalpy/tests/` | pytest suite — integration tests per backend, unit tests for models/reporting/checks | |
| 29 | +| `docs/source/notebooks/` | Jupyter how-to notebooks (named `{method}_{backend}.ipynb`) | |
| 30 | +| `docs/source/knowledgebase/` | Educational content (glossary, reporting statistics explainers) | |
| 31 | + |
| 32 | +## Backend Architecture |
| 33 | + |
| 34 | +CausalPy supports two model backends dispatched via `isinstance` checks: |
| 35 | + |
| 36 | +### Dispatch Pattern |
| 37 | + |
| 38 | +```python |
| 39 | +if isinstance(self.model, PyMCModel): |
| 40 | + # Bayesian path — xarray DataArrays, InferenceData |
| 41 | +elif isinstance(self.model, RegressorMixin): |
| 42 | + # OLS path — numpy arrays |
| 43 | +``` |
| 44 | + |
| 45 | +This pattern appears in `BaseExperiment.__init__` (validation), `_render_plot` (plotting), `get_plot_data` (data export), and each experiment's `algorithm()` and `effect_summary()`. |
| 46 | + |
| 47 | +### PyMCModel (Bayesian backend) |
| 48 | + |
| 49 | +`PyMCModel` extends `pymc.Model` and provides a sklearn-like interface: |
| 50 | + |
| 51 | +| Method | Contract | |
| 52 | +|--------|----------| |
| 53 | +| `build_model(X, y, coords)` | Define PyMC model graph. Must register `pm.Data("X", ...)` and `pm.Data("y", ...)`, create `mu` Deterministic and `y_hat` likelihood with dims `["obs_ind", "treated_units"]` | |
| 54 | +| `fit(X, y, coords)` | Calls `build_model`, then `pm.sample`, `sample_prior_predictive`, `sample_posterior_predictive`. Returns `az.InferenceData` | |
| 55 | +| `predict(X)` | Calls `_data_setter(X)` then `sample_posterior_predictive` for `["y_hat", "mu"]`. Returns `az.InferenceData` | |
| 56 | +| `score(X, y)` | Computes Bayesian R² per treated unit. Returns `pd.Series` | |
| 57 | +| `calculate_impact(y_true, y_pred)` | `y_true - y_pred["posterior_predictive"]["mu"]` (uses mu, NOT y_hat) | |
| 58 | +| `print_coefficients(labels)` | Prints posterior mean + HDI for each coefficient | |
| 59 | + |
| 60 | +Priors use `pymc_extras.Prior` objects. Priority: user-specified > `priors_from_data()` > `default_priors` class attribute. |
| 61 | + |
| 62 | +### ScikitLearnAdaptor (OLS backend) |
| 63 | + |
| 64 | +`ScikitLearnAdaptor` is a mixin class providing CausalPy-compatible methods: |
| 65 | + |
| 66 | +- `calculate_impact(y_true, y_pred)` → `y_true - y_pred` (numpy subtraction) |
| 67 | +- `calculate_cumulative_impact(impact)` → `np.cumsum(impact)` |
| 68 | +- `print_coefficients(labels)` → prints `coef_` values |
| 69 | +- `get_coeffs()` → `np.squeeze(self.coef_)` |
| 70 | + |
| 71 | +`create_causalpy_compatible_class(estimator)` takes an instantiated sklearn `RegressorMixin` and monkey-patches `ScikitLearnAdaptor` methods onto it via `_add_mixin_methods`. Returns the mutated instance (not a new class). |
| 72 | + |
| 73 | +### supports_ols / supports_bayes |
| 74 | + |
| 75 | +Every experiment class declares these as class attributes. `BaseExperiment.__init__` raises `ValueError` if the wrong model type is passed. |
| 76 | + |
| 77 | +### _default_model_class |
| 78 | + |
| 79 | +When `model=None` is passed to an experiment, `BaseExperiment.__init__` instantiates `_default_model_class()` with no arguments. This always produces a Bayesian model. Experiments without a default (e.g. `PanelRegression`) raise `ValueError` if `model=None`. |
| 80 | + |
| 81 | +## Experiment Lifecycle |
| 82 | + |
| 83 | +### 1. Instantiation |
| 84 | + |
| 85 | +```text |
| 86 | +ExperimentClass(data, ..., model=None) |
| 87 | + → BaseExperiment.__init__(model) |
| 88 | + → wrap sklearn model via create_causalpy_compatible_class() if needed |
| 89 | + → instantiate _default_model_class if model is None |
| 90 | + → validate supports_ols / supports_bayes |
| 91 | + → self.data = data; self.formula = formula |
| 92 | + → input_validation(...) |
| 93 | + → _build_design_matrices() |
| 94 | + → _prepare_data() (convert to xarray DataArrays) |
| 95 | + → algorithm() |
| 96 | +``` |
| 97 | + |
| 98 | +Most experiments fit eagerly in `__init__` — instantiation triggers the full MCMC run. There is no separate `.fit()` on the experiment (only on the model). |
| 99 | + |
| 100 | +### 2. _build_design_matrices() |
| 101 | + |
| 102 | +Uses patsy `dmatrices(formula, data_pre)` for pre-intervention data, stores `design_info`. Uses `build_design_matrices([y_design_info, x_design_info], data_post)` for post-intervention data — this ensures consistent encoding for out-of-sample prediction. |
| 103 | + |
| 104 | +### 3. _prepare_data() |
| 105 | + |
| 106 | +Converts numpy design matrices into `xr.DataArray` with dims `["obs_ind", "coeffs"]` for X and `["obs_ind", "treated_units"]` for y. |
| 107 | + |
| 108 | +### 4. algorithm() |
| 109 | + |
| 110 | +Per-experiment fitting logic: |
| 111 | +1. `model.fit(X_pre, y_pre, coords)` — train on pre-intervention data |
| 112 | +2. `model.score(X_pre, y_pre)` — evaluate fit quality |
| 113 | +3. `model.predict(X_pre)` — in-sample predictions |
| 114 | +4. `model.predict(X_post)` — counterfactual predictions |
| 115 | +5. `model.calculate_impact(y_post, post_pred)` — causal effect |
| 116 | +6. `model.calculate_cumulative_impact(impact)` — cumulative effect |
| 117 | + |
| 118 | +### 5. _render_plot() |
| 119 | + |
| 120 | +Template method called by each subclass's public `plot()`: |
| 121 | +1. Applies `arviz-darkgrid` style context |
| 122 | +2. Dispatches to `_bayesian_plot(**draw_kwargs)` or `_ols_plot(**draw_kwargs)` based on model type |
| 123 | +3. Applies `legend_kwargs` in-place to preserve custom handles |
| 124 | +4. Optionally calls `plt.show()` |
| 125 | + |
| 126 | +### 6. effect_summary() |
| 127 | + |
| 128 | +Abstract method on `BaseExperiment`. Each subclass implements it using helpers from `causalpy.reporting`: |
| 129 | +- Bayesian: `_compute_statistics()` → `_generate_table()` → `_generate_prose_detailed()` |
| 130 | +- OLS: `_compute_statistics_ols()` → `_generate_table_ols()` → `_generate_prose_detailed_ols()` |
| 131 | + |
| 132 | +Returns `EffectSummary(table=pd.DataFrame, text=str)`. |
| 133 | + |
| 134 | +### 7. get_plot_data() |
| 135 | + |
| 136 | +Dispatches to `get_plot_data_bayesian()` or `get_plot_data_ols()`. Returns a `pd.DataFrame` with columns for predictions, impacts, and HDI bounds. |
| 137 | + |
| 138 | +## Formula and Data Pipeline |
| 139 | + |
| 140 | +### patsy workflow |
| 141 | + |
| 142 | +1. `dmatrices(formula, df_pre)` → `(y, X)` numpy DesignMatrix objects |
| 143 | +2. Store `y.design_info` and `X.design_info` for later reuse |
| 144 | +3. `build_design_matrices([y_design_info, x_design_info], df_post)` → counterfactual matrices with consistent factor encoding |
| 145 | +4. `self.labels = X.design_info.column_names` — coefficient names |
| 146 | + |
| 147 | +### Custom transforms (PiecewiseITS) |
| 148 | + |
| 149 | +`step(time, threshold)` → binary indicator `(time >= threshold)` |
| 150 | +`ramp(time, threshold)` → `max(0, time - threshold)` |
| 151 | + |
| 152 | +Both are patsy `stateful_transform` objects that memorize datetime origin during first pass and convert datetime to numeric days internally. |
| 153 | + |
| 154 | +### obs_ind index naming |
| 155 | + |
| 156 | +All experiments rename `data.index.name = "obs_ind"`. This is the canonical dimension name for xarray DataArrays and PyMC model coordinates. |
| 157 | + |
| 158 | +## Data Contracts |
| 159 | + |
| 160 | +### PyMC Backend |
| 161 | + |
| 162 | +| Object | Type | Dims/Coords | |
| 163 | +|--------|------|-------------| |
| 164 | +| X (input) | `xr.DataArray` | `["obs_ind", "coeffs"]` with coord values | |
| 165 | +| y (input) | `xr.DataArray` | `["obs_ind", "treated_units"]` — always 2D | |
| 166 | +| coords dict | `dict` | Keys: `"coeffs"`, `"obs_ind"`, `"treated_units"` (required) | |
| 167 | +| fit() return | `az.InferenceData` | Contains posterior, prior_predictive, posterior_predictive | |
| 168 | +| predict() return | `az.InferenceData` | `posterior_predictive` group with `mu` and `y_hat` vars | |
| 169 | +| mu | `xr.DataArray` | Deterministic mean; dims `["chain", "draw", "obs_ind", "treated_units"]` | |
| 170 | +| y_hat | `xr.DataArray` | Observation with noise; same dims as mu | |
| 171 | +| impact | `xr.DataArray` | `y_true - mu`; trailing dim is `"obs_ind"` | |
| 172 | + |
| 173 | +Key conventions: |
| 174 | +- `treated_units` dim is **always present and 2D** even for single-unit experiments (value: `["unit_0"]`) |
| 175 | +- Impact uses `mu` (posterior expectation), NOT `y_hat` (with observation noise) |
| 176 | +- `coeffs_raw` dim appears in softmax models (N-1 logits, first pinned to zero) |
| 177 | + |
| 178 | +### sklearn Backend |
| 179 | + |
| 180 | +| Object | Type | Notes | |
| 181 | +|--------|------|-------| |
| 182 | +| X (input) | `np.ndarray` or `xr.DataArray` | 2D, shape `(n_obs, n_features)` | |
| 183 | +| y (input) | `np.ndarray` or 1D `xr.DataArray` | `.isel(treated_units=0)` before passing to sklearn | |
| 184 | +| predict() return | `np.ndarray` | Shape `(n_obs, 1)` or `(n_obs,)` | |
| 185 | +| coef_ | `np.ndarray` | Accessed via `get_coeffs()` → squeezed | |
| 186 | +| impact | `np.ndarray` | Simple `y_true - y_pred` | |
| 187 | + |
| 188 | +## Experiment Inventory |
| 189 | + |
| 190 | +| Class | Causal Method | `supports_ols` | `supports_bayes` | Default Model | Notable Quirks | |
| 191 | +|-------|--------------|-----------------|-------------------|---------------|----------------| |
| 192 | +| `InterruptedTimeSeries` | ITS (pre/post fit) | Yes | Yes | `LinearRegression` | Supports 3-period design via `treatment_end_time`; eager fit in `__init__` | |
| 193 | +| `PiecewiseITS` | ITS (segmented regression) | Yes | Yes | `LinearRegression` | Fits full time series (not pre-only); uses `step()`/`ramp()` transforms | |
| 194 | +| `DifferenceInDifferences` | DiD | Yes | Yes | `LinearRegression` | Fits all data (no pre/post split); effect from interaction coefficient | |
| 195 | +| `StaggeredDifferenceInDifferences` | Staggered DiD (BJS imputation) | Yes | Yes | `LinearRegression` | Fits untreated observations only; validates absorbing treatment | |
| 196 | +| `SyntheticControl` | SC | Yes | Yes | `WeightedSumFitter` | Multi-unit (multiple `treated_units`); no formula — uses control/treated unit lists | |
| 197 | +| `SyntheticDifferenceInDifferences` | SDiD | Yes | Yes | `SDiDWeightFitter` | Cut-posterior: tau computed analytically from weight posteriors | |
| 198 | +| `RegressionDiscontinuity` | RD (sharp) | Yes | Yes | `LinearRegression` | `epsilon` parameter for causal effect evaluation at threshold; optional `bandwidth` | |
| 199 | +| `RegressionKink` | RKD | No | Yes | `LinearRegression` | `kink_point` instead of threshold; evaluates slope change | |
| 200 | +| `PrePostNEGD` | Pretest/posttest | No | Yes | `LinearRegression` | Uses `group_variable_name` and `pretreatment_variable_name` | |
| 201 | +| `InversePropensityWeighting` | IPW | No | Yes | `PropensityScore` | Non-standard: two-stage (propensity then outcome); no unified `plot()` | |
| 202 | +| `InstrumentalVariable` | IV/2SLS | No | Yes | `IVRegression` | Non-standard `fit()` signature (X, Z, y, t, coords, priors); no unified `plot()` | |
| 203 | +| `PanelRegression` | Panel FE | Yes | Yes | None (required) | Supports demeaned and dummy-variable FE; no `_default_model_class` | |
| 204 | + |
| 205 | +## PyMC Model Inventory |
| 206 | + |
| 207 | +| Class | Purpose | Used By | |
| 208 | +|-------|---------|---------| |
| 209 | +| `PyMCModel` | Abstract base — provides fit/predict/score/calculate_impact contract | All Bayesian experiments (via subclasses) | |
| 210 | +| `LinearRegression` | Standard linear model: `y ~ Normal(X·β, σ)` | ITS, DiD, RD, RKD, PrePostNEGD, PiecewiseITS, StaggeredDiD, PanelRegression | |
| 211 | +| `WeightedSumFitter` | Dirichlet-weighted sum: `y ~ Normal(X·β, σ)` where `β ~ Dirichlet(1)` | SyntheticControl | |
| 212 | +| `SoftmaxWeightedSumFitter` | Softmax-Normal simplex weights (alternative to Dirichlet) | SyntheticControl (alternative) | |
| 213 | +| `SyntheticDifferenceInDifferencesWeightFitter` | Joint unit + time weight model for SDiD | SyntheticDifferenceInDifferences | |
| 214 | +| `InstrumentalVariableRegression` | 2SLS with correlated errors (LKJ covariance or binary treatment) | InstrumentalVariable | |
| 215 | +| `PropensityScore` | Logistic propensity model: `t ~ Bernoulli(logit⁻¹(X·b))` | InversePropensityWeighting | |
| 216 | +| `BayesianBasisExpansionTimeSeries` | Trend + seasonality via pymc-marketing components (experimental) | InterruptedTimeSeries (alternative) | |
| 217 | +| `StateSpaceTimeSeries` | State-space model via pymc-extras structural (experimental) | InterruptedTimeSeries (alternative) | |
| 218 | + |
| 219 | +## Extension Guide |
| 220 | + |
| 221 | +### Add a new experiment class |
| 222 | + |
| 223 | +1. Create `causalpy/experiments/your_method.py` |
| 224 | +2. Subclass `BaseExperiment` |
| 225 | +3. Set `supports_ols`, `supports_bayes`, and optionally `_default_model_class` |
| 226 | +4. Implement `__init__` calling `super().__init__(model=model)` then `_build_design_matrices()` → `algorithm()` |
| 227 | +5. Implement `algorithm()` with the fit/predict/impact flow |
| 228 | +6. Implement `_bayesian_plot()` and/or `_ols_plot()` (only for supported backends) |
| 229 | +7. Implement `effect_summary()` using helpers from `causalpy.reporting` |
| 230 | +8. Declare an explicit public `plot(*, ...)` method with kwarg-only signature that calls `self._render_plot(...)` |
| 231 | +9. Export from `causalpy/experiments/__init__.py` and `causalpy/__init__.py` |
| 232 | + |
| 233 | +### Add a new PyMC model |
| 234 | + |
| 235 | +1. Add class to `causalpy/pymc_models.py` inheriting from `PyMCModel` |
| 236 | +2. Implement `build_model(X, y, coords)` — must create `pm.Data("X", ...)`, `pm.Data("y", ...)`, a `pm.Deterministic("mu", ..., dims=["obs_ind", "treated_units"])`, and a likelihood named `"y_hat"` |
| 237 | +3. Set `default_priors` dict with `Prior` objects |
| 238 | +4. Optionally override `priors_from_data(X, y)` for data-adaptive priors |
| 239 | +5. Optionally override `_data_setter(X)` if prediction requires custom data updates |
| 240 | +6. If `fit()` signature differs from base (non-standard arguments), override it with `# type: ignore[override]` |
| 241 | + |
| 242 | +### Add a new sklearn-compatible model |
| 243 | + |
| 244 | +1. Create a class inheriting from both `ScikitLearnAdaptor` and sklearn's `LinearModel` + `RegressorMixin` |
| 245 | +2. Implement `fit(X, y)` and `predict(X)` — store coefficients in `self.coef_` as 2D array |
| 246 | +3. Alternatively, pass any fitted sklearn `RegressorMixin` instance — `create_causalpy_compatible_class()` will monkey-patch the adapter methods automatically |
| 247 | + |
| 248 | +### Add a new plotting backend or report format |
| 249 | + |
| 250 | +- For plots: override `_bayesian_plot()` / `_ols_plot()` in experiment subclass, or create a new dispatch in `_render_plot()` |
| 251 | +- For reports: extend `causalpy/steps/report.py` (`GenerateReport` step produces HTML); the experiment's `generate_report()` method wraps this |
| 252 | +- For table export: add a new adapter in `causalpy/maketables_adapters.py` implementing the `MaketablesAdapter` protocol |
| 253 | + |
| 254 | +## Key Conventions and Gotchas |
| 255 | + |
| 256 | +| Topic | Detail | |
| 257 | +|-------|--------| |
| 258 | +| **Intercept handling** | Patsy includes an intercept by default (`1 +` in formula). sklearn models must use `fit_intercept=False` because the intercept is already in the design matrix as a column of ones. | |
| 259 | +| **treated_units is always 2D** | Even single-unit experiments use `treated_units=["unit_0"]`. y is always shape `(n_obs, n_treated)`. Never pass 1D y to a PyMC model. | |
| 260 | +| **Impact uses mu, not y_hat** | `calculate_impact()` subtracts `posterior_predictive["mu"]` (expected value), not `["y_hat"]` (with observation noise). This gives cleaner effect estimates reflecting only parameter uncertainty. | |
| 261 | +| **labels must align with coefficients** | `self.labels` comes from `X.design_info.column_names` and must match the `coeffs` dimension of the posterior `beta` variable exactly in order and length. | |
| 262 | +| **maketables adapter dispatch** | `get_maketables_adapter(model)` mirrors the `isinstance` dispatch: `PyMCModel` → `PyMCMaketablesAdapter`, `RegressorMixin` → `SklearnMaketablesAdapter`. | |
| 263 | +| **obs_ind index naming** | Experiments rename `data.index.name = "obs_ind"` early. All xarray dims use this name. PyMC coords key must be `"obs_ind"`. | |
| 264 | +| **design_info for out-of-sample** | Store `_x_design_info` and `_y_design_info` from `dmatrices()`. Use `build_design_matrices([info], new_data)` for counterfactual prediction to preserve factor encoding. | |
| 265 | +| **Eager fitting in __init__** | Most experiments run MCMC during `__init__`. There is no lazy `.fit()` — the experiment object is fully fitted upon construction. | |
| 266 | +| **HDI_PROB default** | The project uses 0.94 (matching ArviZ default), NOT 0.95. `effect_summary()` defaults to `alpha=0.05` (95% HDI), which is independent of `HDI_PROB`. | |
| 267 | +| **SyntheticControl multi-unit** | `SyntheticControl` loops over `treated_units` fitting one model per unit (via `_clone()`). Each unit gets its own `pre_pred`, `post_pred`, `impact`. | |
| 268 | +| **SDiD cut-posterior** | Treatment effect tau is NOT estimated inside the MCMC model. Unit and time weights are sampled jointly, then tau is computed analytically via double-differencing the observed data with the weight posteriors. | |
| 269 | +| **InstrumentalVariable non-standard fit** | `IVRegression.fit(X, Z, y, t, coords, priors, ...)` — does NOT follow the base `fit(X, y, coords)` signature. The experiment class handles this internally. | |
| 270 | +| **create_causalpy_compatible_class mutates** | This function mutates the passed instance (adds methods), it does NOT create a new class or return a new instance. The name is misleading. | |
| 271 | +| **Pipeline vs direct instantiation** | Experiments can be used standalone (just instantiate) or via `Pipeline` with steps. The pipeline adds `SensitivityAnalysis` and `GenerateReport` on top. | |
0 commit comments