-
Notifications
You must be signed in to change notification settings - Fork 103
Consolidate experiment design-matrix attributes into xr.Dataset #849
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4e5c0d6
08ee2c3
aafb28c
44e13ec
ba0faf5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,13 +18,16 @@ | |
| from __future__ import annotations | ||
|
|
||
| import contextlib | ||
| import warnings | ||
| from abc import ABC, abstractmethod | ||
| from pathlib import Path | ||
| from typing import Any, Literal | ||
|
|
||
| import arviz as az | ||
| import matplotlib.pyplot as plt | ||
| import numpy as np | ||
| import pandas as pd | ||
| import xarray as xr | ||
| from sklearn.base import RegressorMixin | ||
|
|
||
| from causalpy.maketables_adapters import get_maketables_adapter | ||
|
|
@@ -114,6 +117,67 @@ class BaseExperiment(ABC): | |
|
|
||
| _default_model_class: type[PyMCModel] | None = None | ||
|
|
||
| _deprecated_design_aliases: dict[str, tuple[str, str]] = {} | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Quick one — |
||
| """Mapping of ``old_attr -> (dataset_attr, key)`` for deprecated design | ||
| matrix accessors. Subclasses populate this so that | ||
| ``__getattr__`` can forward accesses with a deprecation warning.""" | ||
|
|
||
| def __getattr__(self, name: str) -> Any: | ||
| aliases = type(self)._deprecated_design_aliases | ||
| if name in aliases: | ||
| dataset_attr, key = aliases[name] | ||
| warnings.warn( | ||
| f"{name} is deprecated, use {dataset_attr}['{key}']", | ||
| DeprecationWarning, | ||
| stacklevel=2, | ||
| ) | ||
| return getattr(self, dataset_attr)[key] | ||
| raise AttributeError( | ||
| f"'{type(self).__name__}' object has no attribute '{name}'" | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def _build_design_dataset( | ||
| X_raw: np.ndarray, | ||
| y_raw: np.ndarray, | ||
| *, | ||
| obs_ind: np.ndarray | pd.Index, | ||
| coeffs: list[str], | ||
| treated_units: list[str] | None = None, | ||
| ) -> xr.Dataset: | ||
| """Build a standard ``xr.Dataset`` from raw design matrices. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| X_raw : np.ndarray | ||
| Predictor matrix, shape ``(n_obs, n_coeffs)``. | ||
| y_raw : np.ndarray | ||
| Outcome matrix, shape ``(n_obs, n_units)``. | ||
| obs_ind : array-like | ||
| Observation index coordinates. | ||
| coeffs : list[str] | ||
| Coefficient / column names for ``X_raw``. | ||
| treated_units : list[str], optional | ||
| Names for the treated-unit dimension of ``y_raw``. | ||
| Defaults to ``["unit_0"]``. | ||
| """ | ||
| if treated_units is None: | ||
| treated_units = ["unit_0"] | ||
| return xr.Dataset( | ||
| { | ||
| "X": xr.DataArray( | ||
| X_raw, | ||
| dims=["obs_ind", "coeffs"], | ||
| coords={"obs_ind": obs_ind, "coeffs": coeffs}, | ||
| ), | ||
| "y": xr.DataArray( | ||
| y_raw, | ||
| dims=["obs_ind", "treated_units"], | ||
| coords={"obs_ind": obs_ind, "treated_units": treated_units}, | ||
| ), | ||
| } | ||
| ) | ||
|
|
||
| def __init__(self, model: PyMCModel | RegressorMixin | None = None) -> None: | ||
| # Ensure we've made any provided Scikit Learn model (as identified as being type | ||
| # RegressorMixin) compatible with CausalPy by appending our custom methods. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tiny nit: since
applicable_methods = {SyntheticControl}andvalidate()already enforces the type, you could drop both# type: ignore[attr-defined]markers by addingassert isinstance(experiment, SyntheticControl)(or just callingself.validate(experiment)) at the top ofrun()— that narrowssctoSyntheticControlandpre_designbecomes a known attribute for mypy. Same effect, no escape hatches.