diff --git a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb new file mode 100644 index 000000000..4d27bcb6d --- /dev/null +++ b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb @@ -0,0 +1,543 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f9e8547b", + "metadata": {}, + "source": [ + "# Notes to self:\n", + "Currently, I am able to get the model to sample with batch dimensions and I am able to use the following internal methods:\n", + "\n", + "* `sample_unconditional_prior`\n", + "* `sample_conditional_prior`\n", + "* `sample_unconditional_posterior`\n", + "* `sample_conditional_posterior`\n", + "* `forecast`\n", + "* `sample_filter_outputs`\n", + "* `sample_statespace_matrices`\n", + "* `impulse_response_function`\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d95e629f", + "metadata": {}, + "outputs": [], + "source": [ + "import arviz as az\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pymc as pm\n", + "import pytensor.tensor as pt\n", + "\n", + "from pymc_extras.statespace.core.statespace import PyMCStateSpace\n", + "from pymc_extras.statespace.filters import StandardFilter, KalmanSmoother\n", + "\n", + "from pymc_extras.statespace.core.properties import (\n", + " Parameter,\n", + " State,\n", + " Shock,\n", + " Coord,\n", + ")\n", + "from pymc_extras.statespace.utils.constants import ALL_STATE_DIM, ALL_STATE_AUX_DIM, SHOCK_DIM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34b32aa3", + "metadata": {}, + "outputs": [], + "source": [ + "class AutoRegressiveThree(PyMCStateSpace):\n", + " def __init__(self, mode: str):\n", + " k_states = 3 # size of the state vector x\n", + " k_posdef = 1 # number of shocks (size of the state covariance matrix Q)\n", + " k_endog = 1 # number of observed states\n", + "\n", + " super().__init__(\n", + " k_endog=k_endog,\n", + " k_states=k_states,\n", + " k_posdef=k_posdef,\n", + " mode=mode,\n", + " )\n", + "\n", + " def make_symbolic_graph(self):\n", + " x0 = self.make_and_register_variable(\"x0\", shape=(3,))\n", + " P0 = self.make_and_register_variable(\"P0\", shape=(3, 3))\n", + "\n", + " ar_params = self.make_and_register_variable(\"ar_params\", shape=(3,))\n", + " sigma_x = self.make_and_register_variable(\"sigma_x\", shape=(1,))\n", + "\n", + " self.ssm[\"transition\", :, :] = np.eye(3, k=-1)\n", + " self.ssm[\"selection\", 0, 0] = 1\n", + " self.ssm[\"design\", 0, 0] = 1\n", + "\n", + " self.ssm[\"initial_state\", :] = x0\n", + " self.ssm[\"initial_state_cov\", :, :] = P0\n", + " self.ssm[\"transition\", 0, :] = ar_params\n", + " self.ssm[\"state_cov\", :, :] = sigma_x\n", + "\n", + " def set_parameters(self):\n", + " # Only the \"name\" parameter is required here. \"Shape\" is only used when printing the\n", + " # model requirements table. \"Dims\" are used to link variables to coords.\n", + " x0 = Parameter(name=\"x0\", shape=(3,), dims=(ALL_STATE_DIM,))\n", + " P0 = Parameter(name=\"P0\", shape=(3, 3), dims=(ALL_STATE_DIM, ALL_STATE_AUX_DIM))\n", + "\n", + " ar_params = Parameter(\n", + " name=\"ar_params\", shape=(3,), dims=(\"ar_lags\",), constraints=\"Stationary, please :)\"\n", + " )\n", + " sigma_x = Parameter(name=\"sigma_x\", shape=(1,), dims=(SHOCK_DIM,))\n", + " return x0, P0, ar_params, sigma_x\n", + "\n", + " def set_states(self):\n", + " # To get a name on the observed, we make an observed state\n", + " ts1 = State(name=\"ts1\", observed=True)\n", + "\n", + " # Since the three hidden states are lags of the data, i'll call them L1, L2 L3\n", + " L1 = State(name=\"L1.data\", observed=False)\n", + " L2 = State(name=\"L2.data\", observed=False)\n", + " L3 = State(name=\"L3.data\", observed=False)\n", + "\n", + " return ts1, L1, L2, L3\n", + "\n", + " def set_shocks(self):\n", + " # There is one shock, called the \"innovations\" in the literature, so i'll go with that\n", + " innovation = Shock(name=\"innovations\")\n", + " return innovation\n", + "\n", + " def set_coords(self):\n", + " # This function sets up the coords dictionary used by pm.Model. The parent class has a helper\n", + " # self.default_coords() that makes the coords that are always expected by a statespace model --\n", + " # stuff like state, shock, etc.\n", + "\n", + " # You need to give one Coord object per dimension used among the Parameter objects you made a\n", + "\n", + " default_coords = self.default_coords()\n", + " ar_coord = Coord(dimension=\"ar_lags\", labels=(1, 2, 3))\n", + " return *default_coords, ar_coord" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfae06f5", + "metadata": {}, + "outputs": [], + "source": [ + "ar3 = AutoRegressiveThree(mode=\"NUMBA\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27119d27", + "metadata": {}, + "outputs": [], + "source": [ + "data = np.random.normal(0, 1, size=(100, 10))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d9163bda", + "metadata": {}, + "outputs": [], + "source": [ + "batched_data = data.reshape(10, 100, 1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8ed64433", + "metadata": {}, + "outputs": [], + "source": [ + "batched_data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "217b812d", + "metadata": {}, + "outputs": [], + "source": [ + "# Not vectorized\n", + "with pm.Model(coords=ar3.coords) as pymc_mod:\n", + " x0 = pm.Deterministic(\"x0\", pt.zeros((3,)), dims=(\"state\"))\n", + " P0 = pm.Deterministic(\"P0\", pt.eye(3) * 10, dims=(\"state\", \"state_aux\"))\n", + " ar_params = pm.Normal(\"ar_params\", shape=(3,), dims=(\"state\"))\n", + "\n", + " sigma_x = pm.Exponential(\"sigma_x\", 1, shape=(1,), dims=(\"shock\"))\n", + "\n", + " ar3.build_statespace_graph(data=data[:, [1]])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84ac6e46", + "metadata": {}, + "outputs": [], + "source": [ + "# Vectorized\n", + "with pm.Model(\n", + " coords=ar3.coords\n", + " | {\n", + " \"batch\": [\n", + " \"batch_1\",\n", + " \"batch_2\",\n", + " \"batch_3\",\n", + " \"batch_4\",\n", + " \"batch_5\",\n", + " \"batch_6\",\n", + " \"batch_7\",\n", + " \"batch_8\",\n", + " \"batch_9\",\n", + " \"batch_10\",\n", + " ]\n", + " }\n", + ") as pymc_mod:\n", + " x0 = pm.Deterministic(\"x0\", pt.zeros((10, 3)), dims=(\"batch\", \"state\"))\n", + " P0 = pm.Deterministic(\n", + " \"P0\", pt.tile(pt.eye(3) * 1, (10, 1, 1)), dims=(\"batch\", \"state\", \"state_aux\")\n", + " )\n", + " ar_params = pm.Normal(\"ar_params\", dims=(\"batch\", \"state\"))\n", + "\n", + " sigma_x = pm.Exponential(\"sigma_x\", 1, dims=(\"batch\", \"shock\"))\n", + "\n", + " ar3.build_statespace_graph(data=batched_data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed208bcd", + "metadata": {}, + "outputs": [], + "source": [ + "with pymc_mod:\n", + " idata = pm.sample(tune=20, draws=20, compile_kwargs={\"mode\": \"NUMBA\"})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c172bc3a", + "metadata": {}, + "outputs": [], + "source": [ + "ar3.impulse_response_function(idata)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "674ae03e", + "metadata": {}, + "outputs": [], + "source": [ + "ar3.ssm[\"design\"].eval().shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0fd9532e", + "metadata": {}, + "outputs": [], + "source": [ + "pt.tile(ar3.ssm[\"design\"], (3, 1)).eval().shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24035cf0", + "metadata": {}, + "outputs": [], + "source": [ + "ar3.sample_statespace_matrices(idata, matrix_names=[\"Z\"]).posterior_predictive" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64940705", + "metadata": {}, + "outputs": [], + "source": [ + "ar3.sample_filter_outputs(idata, filter_output_names=[\"smoothed_covariances\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92b4b9f3", + "metadata": {}, + "outputs": [], + "source": [ + "ar3.forecast(idata, periods=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf2af58a", + "metadata": {}, + "outputs": [], + "source": [ + "post = ar3.sample_conditional_posterior(idata, mvn_method=\"cholesky\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5c3800a9", + "metadata": {}, + "outputs": [], + "source": [ + "with pymc_mod:\n", + " prior = pm.sample_prior_predictive(compile_kwargs={\"mode\": \"NUMBA\"})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5857a994", + "metadata": {}, + "outputs": [], + "source": [ + "ar3.sample_conditional_prior(prior, mvn_method=\"cholesky\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b8969d2", + "metadata": {}, + "outputs": [], + "source": [ + "ar3.sample_unconditional_prior(prior, mvn_method=\"cholesky\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "91d8c9aa", + "metadata": {}, + "outputs": [], + "source": [ + "unpost = ar3.sample_unconditional_posterior(idata, mvn_method=\"cholesky\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b112c551", + "metadata": {}, + "outputs": [], + "source": [ + "unpost" + ] + }, + { + "cell_type": "markdown", + "id": "174c0e9e", + "metadata": {}, + "source": [ + "# MVN Method" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c456c29a", + "metadata": {}, + "outputs": [], + "source": [ + "class AutoRegressive3TwoSeries(PyMCStateSpace):\n", + " def __init__(self, mode: str):\n", + " k_states = 6 # 2 series × 3 lags\n", + " k_posdef = 2 # one innovation per series\n", + " k_endog = 2 # two observed series\n", + "\n", + " super().__init__(k_endog=k_endog, k_states=k_states, k_posdef=k_posdef, mode=mode)\n", + "\n", + " def make_symbolic_graph(self):\n", + " x0 = self.make_and_register_variable(\"x0\", shape=(6,))\n", + " P0 = self.make_and_register_variable(\"P0\", shape=(6, 6))\n", + "\n", + " ar_params = self.make_and_register_variable(\"ar_params\", shape=(2, 3))\n", + " sigma_x = self.make_and_register_variable(\"sigma_x\", shape=(2,))\n", + "\n", + " T = np.eye(6, k=-1)\n", + "\n", + " self.ssm[\"transition\", :, :] = T\n", + " self.ssm[\"transition\", 0, 0:3] = ar_params[0]\n", + " self.ssm[\"transition\", 3, 3:6] = ar_params[1]\n", + "\n", + " self.ssm[\"selection\", 0, 0] = 1\n", + " self.ssm[\"selection\", 3, 1] = 1\n", + "\n", + " self.ssm[\"state_cov\", :, :] = pt.diag(sigma_x)\n", + "\n", + " Z = np.zeros((2, 6))\n", + " Z[0, 0] = 1\n", + " Z[1, 3] = 1\n", + " self.ssm[\"design\", :, :] = Z\n", + "\n", + " self.ssm[\"initial_state\", :] = x0\n", + " self.ssm[\"initial_state_cov\", :, :] = P0\n", + "\n", + " def set_parameters(self):\n", + " x0 = Parameter(name=\"x0\", shape=(6,), dims=(ALL_STATE_DIM,))\n", + " P0 = Parameter(name=\"P0\", shape=(6, 6), dims=(ALL_STATE_DIM, ALL_STATE_AUX_DIM))\n", + "\n", + " ar_params = Parameter(\n", + " name=\"ar_params\",\n", + " shape=(2, 3),\n", + " dims=(\"observed_state\", \"ar_lags\"),\n", + " )\n", + "\n", + " sigma_x = Parameter(\n", + " name=\"sigma_x\",\n", + " shape=(2,),\n", + " dims=(\"observed_state\",),\n", + " )\n", + "\n", + " return x0, P0, ar_params, sigma_x\n", + "\n", + " def set_states(self):\n", + " # Observed states\n", + " ts1 = State(name=\"ts1\", observed=True)\n", + " ts2 = State(name=\"ts2\", observed=True)\n", + "\n", + " # Series 1 states\n", + " L1_s1 = State(name=\"L1.ts1\", observed=False)\n", + " L2_s1 = State(name=\"L2.ts1\", observed=False)\n", + " L3_s1 = State(name=\"L3.ts1\", observed=False)\n", + "\n", + " # Series 2 states\n", + " L1_s2 = State(name=\"L1.ts2\", observed=False)\n", + " L2_s2 = State(name=\"L2.ts2\", observed=False)\n", + " L3_s2 = State(name=\"L3.ts2\", observed=False)\n", + "\n", + " return (\n", + " ts1,\n", + " ts2,\n", + " L1_s1,\n", + " L2_s1,\n", + " L3_s1,\n", + " L1_s2,\n", + " L2_s2,\n", + " L3_s2,\n", + " )\n", + "\n", + " def set_shocks(self):\n", + " eps1 = Shock(name=\"innovation.ts1\")\n", + " eps2 = Shock(name=\"innovation.ts2\")\n", + " return eps1, eps2\n", + "\n", + " def set_coords(self):\n", + " default_coords = self.default_coords()\n", + " ar_coord = Coord(dimension=\"ar_lags\", labels=(1, 2, 3))\n", + " return *default_coords, ar_coord" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26f6673b", + "metadata": {}, + "outputs": [], + "source": [ + "ar3.coords" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b8759374", + "metadata": {}, + "outputs": [], + "source": [ + "ar3.ssm[\"design\"].eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c34976d9", + "metadata": {}, + "outputs": [], + "source": [ + "ar3 = AutoRegressive3TwoSeries(mode=\"NUMBA\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f66a14c9", + "metadata": {}, + "outputs": [], + "source": [ + "with pm.Model(coords=ar3.coords) as pymc_mod:\n", + " x0 = pm.Deterministic(\"x0\", pt.zeros(6), dims=[\"state\"])\n", + " P0 = pm.Deterministic(\"P0\", pt.eye(6) * 10, dims=[\"state\", \"state_aux\"])\n", + "\n", + " # global mean per lag\n", + " rho_global = pm.Normal(\"rho_global\", 0.0, 0.5, dims=[\"ar_lags\"])\n", + " tau = pm.Exponential(\"tau\", 2.0, dims=[\"ar_lags\"])\n", + "\n", + " ar_offset = pm.Normal(\"ar_offset\", 0.0, 1.0, dims=[\"observed_state\", \"ar_lags\"])\n", + "\n", + " ar_params = pm.Deterministic(\n", + " \"ar_params\",\n", + " rho_global + tau * ar_offset,\n", + " dims=[\"observed_state\", \"ar_lags\"],\n", + " )\n", + "\n", + " sigma_x = pm.Exponential(\"sigma_x\", 1.0, dims=[\"observed_state\"])\n", + "\n", + " ar3.build_statespace_graph(data=data)\n", + " idata = pm.sample(compile_kwargs={\"mode\": \"NUMBA\"})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c8ea114", + "metadata": {}, + "outputs": [], + "source": [ + "pymc_mod.to_graphviz()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymc-extras", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.14.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 8e838b648..cc1df1838 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -14,7 +14,7 @@ from pymc.model.transform.optimization import freeze_dims_and_data from pymc.util import RandomState from pytensor.graph.basic import Variable -from pytensor.graph.replace import graph_replace +from pytensor.graph.replace import graph_replace, vectorize_graph from pytensor.graph.traversal import explicit_graph_inputs from rich.box import SIMPLE_HEAD from rich.console import Console @@ -49,9 +49,11 @@ SequenceMvNormal, ) from pymc_extras.statespace.filters.utilities import stabilize +from pymc_extras.statespace.utils.batch_tools import bmv from pymc_extras.statespace.utils.constants import ( ALL_STATE_AUX_DIM, ALL_STATE_DIM, + BATCH_DIM, FILTER_OUTPUT_DIMS, FILTER_OUTPUT_TYPES, JITTER_DEFAULT, @@ -102,6 +104,62 @@ def _validate_property(props, property_name, expected_type): ) +def _infer_batch_dimensions( + core_matrices: list[pt.TensorVariable], + subbed_matrices: list[pt.TensorVariable], +) -> tuple[int | None, ...]: + inferred_batch_dims = () + + for core_matrix, sub_matrix in zip(core_matrices, subbed_matrices): + core_shape = core_matrix.type.shape + sub_shape = sub_matrix.type.shape + + if len(sub_shape) < len(core_shape): + raise ValueError( + f"Subbed matrix has fewer dims than core matrix: " f"{sub_shape} vs {core_shape}" + ) + + # Verify trailing/core dimensions match + trailing_shape = sub_shape[-len(core_shape) :] + + for core_dim, sub_dim in zip(core_shape, trailing_shape): + if core_dim is not None and sub_dim is not None and core_dim != sub_dim: + raise ValueError(f"Core dimension mismatch: " f"{core_shape} vs {sub_shape}") + + batch_dims = sub_shape[: -len(core_shape)] + + # Skip matrices with no batch dimensions + if len(batch_dims) == 0: + continue + + # First batched tensor establishes the batch shape + if len(inferred_batch_dims) == 0: + inferred_batch_dims = batch_dims + continue + + # Validate consistency + if len(batch_dims) != len(inferred_batch_dims): + raise ValueError(f"Inconsistent batch rank: " f"{batch_dims} vs {inferred_batch_dims}") + + merged_dims = [] + + for inferred_dim, new_dim in zip(inferred_batch_dims, batch_dims): + if inferred_dim is None: + merged_dims.append(new_dim) + elif new_dim is None: + merged_dims.append(inferred_dim) + elif inferred_dim == new_dim: + merged_dims.append(inferred_dim) + else: + raise ValueError( + f"Inconsistent batch dimensions: " f"{batch_dims} vs {inferred_batch_dims}" + ) + + inferred_batch_dims = tuple(merged_dims) + + return inferred_batch_dims + + class PyMCStateSpace: r""" Base class for Linear Gaussian Statespace models in PyMC. @@ -874,7 +932,7 @@ def _insert_random_variables(self): matrices = list(self._unpack_statespace_with_placeholders()) replacement_dict = {var: pymc_model[name] for name, var in self._name_to_variable.items()} - self.subbed_ssm = graph_replace(matrices, replace=replacement_dict, strict=True) + self.subbed_ssm = vectorize_graph(matrices, replace=replacement_dict) def _insert_data_variables(self): """ @@ -1011,6 +1069,121 @@ def _register_kalman_filter_outputs_with_pymc_model(outputs: tuple[pt.TensorVari dims = tuple([dim if dim in coords.keys() else None for dim in dim_names]) pm.Deterministic(name, var, dims=dims) + def _maybe_tv(self, name, base_sig, time_varying_names): + """Add leading time dimension if matrix is time varying.""" + if name in time_varying_names: + return f"(t,{base_sig})" + return f"({base_sig})" + + def _build_signature(self, inputs, outputs, time_varying_names): + def resolve(spec): + name, base_sig = spec + + # Non-time-varying literal signature + if name is None: + return f"({base_sig})" + + return self._maybe_tv(name, base_sig, time_varying_names) + + input_sig = ",".join(resolve(spec) for spec in inputs) + output_sig = ",".join(f"({sig})" for sig in outputs) + + return f"{input_sig}->{output_sig}" + + def _vectorize(self, fn, inputs, outputs, time_varying_names): + signature = self._build_signature( + inputs, + outputs, + time_varying_names, + ) + + return pt.vectorize(fn, signature=signature) + + def make_vectorized_filter( + self, + missing_fill_value, + cov_jitter, + time_varying_names, + ): + def vectorize_filter(data, x0, P0, c, d, T, Z, R, H, Q): + return self.kalman_filter.build_graph( + data, + x0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + missing_fill_value, + cov_jitter, + time_varying_names, + ) + + inputs = [ + (None, "t,o"), # data + (None, "k"), # x0 + (None, "k,k"), # P0 + ("c", "k"), + ("d", "o"), + ("T", "k,k"), + ("Z", "o,k"), + ("R", "k,r"), + ("H", "o,o"), + ("Q", "r,r"), + ] + + outputs = [ + "t,k", + "t,k", + "t,o", + "t,k,k", + "t,k,k", + "t,o,o", + "t", + ] + + return self._vectorize( + vectorize_filter, + inputs, + outputs, + time_varying_names, + ) + + def make_vectorized_smoother(self, cov_jitter, time_varying_names): + def vectorize_smoother(T, R, Q, filtered_states, filtered_covariances): + return self.kalman_smoother.build_graph( + T, + R, + Q, + filtered_states, + filtered_covariances, + cov_jitter, + time_varying_names, + ) + + inputs = [ + ("T", "k,k"), + ("R", "k,r"), + ("Q", "r,r"), + (None, "t,k"), + (None, "t,k,k"), + ] + + outputs = [ + "t,k", + "t,k,k", + ] + + return self._vectorize( + vectorize_smoother, + inputs, + outputs, + time_varying_names, + ) + def build_statespace_graph( self, data: np.ndarray | pd.DataFrame | pt.TensorVariable, @@ -1092,32 +1265,53 @@ def build_statespace_graph( self.mode = mode pm_mod = modelcontext(None) - self._insert_random_variables() self._save_exogenous_data_info() self._insert_data_variables() + self.batch_size = _infer_batch_dimensions( + core_matrices=self._unpack_statespace_with_placeholders(), + subbed_matrices=self.unpack_statespace(), + ) + obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None) self._fit_data = data + data_dims = None + + if self.batch_size: + data_dims = (BATCH_DIM, TIME_DIM, OBS_STATE_DIM) + data, nan_mask = register_data_with_pymc( data, n_obs=self.ssm.k_endog, obs_coords=obs_coords, register_data=register_data, missing_fill_value=missing_fill_value, + data_dims=data_dims, ) # Order is important here: only call _insert_data_shape_into_n_timesteps after data has been registered. self._insert_data_shape_into_n_timesteps(data) - filter_outputs = self.kalman_filter.build_graph( - pt.as_tensor_variable(data), - *self.unpack_statespace(), - missing_fill_value=missing_fill_value, - cov_jitter=cov_jitter, - time_varying_names=self.ssm.time_varying_names, - ) + if data_dims and BATCH_DIM in data_dims: + vectorized_filter = self.make_vectorized_filter( + missing_fill_value=missing_fill_value, + cov_jitter=cov_jitter, + time_varying_names=self.ssm.time_varying_names, + ) + + filter_outputs = vectorized_filter( + pt.as_tensor_variable(data), *self.unpack_statespace() + ) + else: + filter_outputs = self.kalman_filter.build_graph( + pt.as_tensor_variable(data), + *self.unpack_statespace(), + missing_fill_value=missing_fill_value, + cov_jitter=cov_jitter, + time_varying_names=self.ssm.time_varying_names, + ) logp = filter_outputs.pop(-1) states, covs = filter_outputs[:3], filter_outputs[3:] @@ -1133,6 +1327,9 @@ def build_statespace_graph( obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_states"] obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None + if data_dims and BATCH_DIM in data_dims: + obs_dims = (BATCH_DIM, *obs_dims) + SequenceMvNormal( "obs", mus=observed_states, @@ -1216,6 +1413,8 @@ def _build_dummy_graph(self) -> None: def infer_variable_shape(name): shape = self._name_to_variable[name].type.shape + shape = (*self.batch_size, *shape) + if not any(dim is None for dim in shape): return shape @@ -1299,6 +1498,9 @@ def _kalman_filter_outputs_from_dummy_graph( obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None) + if self.batch_size and data_dims is None: + data_dims = (BATCH_DIM, TIME_DIM, OBS_STATE_DIM) + data, nan_mask = register_data_with_pymc( data, n_obs=self.ssm.k_endog, @@ -1307,19 +1509,37 @@ def _kalman_filter_outputs_from_dummy_graph( register_data=True, ) - filter_outputs = self.kalman_filter.build_graph( - data, - x0, - P0, - c, - d, - T, - Z, - R, - H, - Q, - time_varying_names=self.ssm.time_varying_names, - ) + if data_dims and BATCH_DIM in data_dims: + vectorized_filter = self.make_vectorized_filter( + None, None, time_varying_names=self.ssm.time_varying_names + ) + + filter_outputs = vectorized_filter( + data, + x0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + ) + else: + filter_outputs = self.kalman_filter.build_graph( + data, + x0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + time_varying_names=self.ssm.time_varying_names, + ) filter_outputs.pop(-1) states, covariances = filter_outputs[:3], filter_outputs[3:] @@ -1327,14 +1547,24 @@ def _kalman_filter_outputs_from_dummy_graph( filtered_states, predicted_states, _ = states filtered_covariances, predicted_covariances, _ = covariances - [smoothed_states, smoothed_covariances] = self.kalman_smoother.build_graph( - T, - R, - Q, - filtered_states, - filtered_covariances, - time_varying_names=self.ssm.time_varying_names, - ) + if data_dims and BATCH_DIM in data_dims: + vectorized_smoother = self.make_vectorized_smoother(1e-8, self.ssm.time_varying_names) + [smoothed_states, smoothed_covariances] = vectorized_smoother( + T, + R, + Q, + filtered_states, + filtered_covariances, + ) + else: + [smoothed_states, smoothed_covariances] = self.kalman_smoother.build_graph( + T, + R, + Q, + filtered_states, + filtered_covariances, + time_varying_names=self.ssm.time_varying_names, + ) grouped_outputs = [ (filtered_states, filtered_covariances), @@ -1417,16 +1647,38 @@ def _sample_conditional( for name, (mu, cov) in zip(FILTER_OUTPUT_TYPES, grouped_outputs): dummy_ll = pt.zeros_like(mu) - state_dims = ( - (TIME_DIM, ALL_STATE_DIM) - if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM]]) - else (None, None) - ) - obs_dims = ( - (TIME_DIM, OBS_STATE_DIM) - if all([dim in self._fit_coords for dim in [TIME_DIM, OBS_STATE_DIM]]) - else (None, None) - ) + if self.batch_size: + state_dims = ( + (BATCH_DIM, TIME_DIM, ALL_STATE_DIM) + if all( + [ + dim in self._fit_coords + for dim in [BATCH_DIM, TIME_DIM, ALL_STATE_DIM] + ] + ) + else (None, None, None) + ) + obs_dims = ( + (BATCH_DIM, TIME_DIM, OBS_STATE_DIM) + if all( + [ + dim in self._fit_coords + for dim in [BATCH_DIM, TIME_DIM, OBS_STATE_DIM] + ] + ) + else (None, None, None) + ) + else: + state_dims = ( + (TIME_DIM, ALL_STATE_DIM) + if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM]]) + else (None, None) + ) + obs_dims = ( + (TIME_DIM, OBS_STATE_DIM) + if all([dim in self._fit_coords for dim in [TIME_DIM, OBS_STATE_DIM]]) + else (None, None) + ) SequenceMvNormal( f"{name}_{group}", @@ -1438,7 +1690,7 @@ def _sample_conditional( ) obs_mu = d + (Z @ mu[..., None]).squeeze(-1) - obs_cov = Z @ cov @ pt.swapaxes(Z, -2, -1) + H + obs_cov = Z @ cov @ pt.swapaxes(Z, -2, -1) + H[..., None, :, :] SequenceMvNormal( f"{name}_{group}_observed", @@ -1551,8 +1803,18 @@ def _sample_unconditional( else: steps = len(temp_coords[TIME_DIM]) - 1 - if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]]): - dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] + if self.batch_size: + if all( + [ + dim in self._fit_coords + for dim in [BATCH_DIM, TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] + ] + ): + dims = [BATCH_DIM, TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] + + else: + if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]]): + dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] with pm.Model(coords=temp_coords if dims is not None else None) as forward_model: self._build_dummy_graph() @@ -1873,7 +2135,19 @@ def sample_statespace_matrices( long_name = SHORT_NAME_TO_LONG[short_name] if (long_name in matrix_names) or (short_name in matrix_names): name = long_name if long_name in matrix_names else short_name - dims = [x if x in self._fit_coords else None for x in MATRIX_DIMS[short_name]] + if self.batch_size: + dims = [ + x if x in self._fit_coords else None for x in MATRIX_DIMS[short_name] + ] + dims = [BATCH_DIM, *dims] + if ( + matrix.type.ndim != len(dims) + ): # This is necessary because vectorize_graph() does not add a batch dim to every matrix + matrix = pt.tile(matrix, (*self.batch_size, 1, 1)) + else: + dims = [ + x if x in self._fit_coords else None for x in MATRIX_DIMS[short_name] + ] pm.Deterministic(name, matrix, dims=dims) # TODO: Remove this after pm.Flat has its initial_value fixed @@ -1915,6 +2189,10 @@ def sample_filter_outputs( compile_kwargs = kwargs.pop("compile_kwargs", {}) compile_kwargs.setdefault("mode", self.mode) + data_dims = None + if self.batch_size: + data_dims = (BATCH_DIM, TIME_DIM, OBS_STATE_DIM) + with pm.Model(coords=self.coords) as m: self._build_dummy_graph() self._insert_random_variables() @@ -1935,35 +2213,85 @@ def sample_filter_outputs( n_obs=self.ssm.k_endog, obs_coords=obs_coords, register_data=True, + data_dims=data_dims, ) - filter_outputs = self.kalman_filter.build_graph( - data, - x0, - P0, - c, - d, - T, - Z, - R, - H, - Q, - time_varying_names=self.ssm.time_varying_names, - ) + if data_dims and BATCH_DIM in data_dims: + vectorized_filter = self.make_vectorized_filter( + None, None, time_varying_names=self.ssm.time_varying_names + ) - smoother_outputs = self.kalman_smoother.build_graph( - T, - R, - Q, - filter_outputs[0], - filter_outputs[3], - time_varying_names=self.ssm.time_varying_names, - ) + filter_outputs = vectorized_filter( + data, + x0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + ) + else: + filter_outputs = self.kalman_filter.build_graph( + data, + x0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + time_varying_names=self.ssm.time_varying_names, + ) + + if data_dims and BATCH_DIM in data_dims: + vectorized_smoother = self.make_vectorized_smoother( + 1e-8, self.ssm.time_varying_names + ) + smoother_outputs = vectorized_smoother( + T, + R, + Q, + filter_outputs[0], + filter_outputs[3], + ) + else: + smoother_outputs = self.kalman_smoother.build_graph( + T, + R, + Q, + filter_outputs[0], + filter_outputs[3], + time_varying_names=self.ssm.time_varying_names, + ) filter_outputs = filter_outputs[:-1] + list(smoother_outputs) + + if self.batch_size: + ordered_filter_output_names = [ + "filtered_states", + "predicted_states", + "predicted_observed_states", + "filtered_covariances", + "predicted_covariances", + "predicted_observed_covariances", + "smoothed_states", + "smoothed_covariances", + ] + # Names disappear when we vectorize need to reassign names + for output, name in zip(filter_outputs, ordered_filter_output_names): + output.name = name + for output in filter_outputs: if output.name in filter_output_names: - dims = FILTER_OUTPUT_DIMS[output.name] + if self.batch_size: + dims = (BATCH_DIM, *FILTER_OUTPUT_DIMS[output.name]) + else: + dims = FILTER_OUTPUT_DIMS[output.name] pm.Deterministic(output.name, output, dims=dims) with freeze_dims_and_data(m): @@ -2349,23 +2677,39 @@ def _build_forecast_model( filter_time_dim = TIME_DIM temp_coords = self._fit_coords.copy() - dims = None - if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]): - dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] - t0_idx = np.flatnonzero(time_index == t0)[0] + idx = (slice(None), t0_idx) if self.batch_size else (t0_idx,) temp_coords["data_time"] = time_index temp_coords[TIME_DIM] = forecast_index - mu_dims, cov_dims = None, None - if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]): - mu_dims = ["data_time", ALL_STATE_DIM] - cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM] + dims, mu_dims, cov_dims = None, None, None + + if self.batch_size: + data_dims = [BATCH_DIM, "data_time", OBS_STATE_DIM] + if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]): + dims = [BATCH_DIM, TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] + if all( + [ + dim in self._fit_coords + for dim in [BATCH_DIM, TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM] + ] + ): + mu_dims = [BATCH_DIM, ALL_STATE_DIM] + cov_dims = [BATCH_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM] + else: + data_dims = ["data_time", OBS_STATE_DIM] + if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]): + dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] + if all( + [dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]] + ): + mu_dims = [ALL_STATE_DIM] + cov_dims = [ALL_STATE_DIM, ALL_STATE_AUX_DIM] with pm.Model(coords=temp_coords) as forecast_model: _, grouped_outputs = self._kalman_filter_outputs_from_dummy_graph( - data_dims=["data_time", OBS_STATE_DIM], + data_dims=data_dims, ) group_idx = FILTER_OUTPUT_TYPES.index(filter_output) @@ -2385,10 +2729,10 @@ def _build_forecast_model( mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True) x0 = pm.Deterministic( - "x0_slice", mu_frozen[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None + "x0_slice", mu_frozen[idx], dims=mu_dims if mu_dims is not None else None ) P0 = pm.Deterministic( - "P0_slice", cov_frozen[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None + "P0_slice", cov_frozen[idx], dims=cov_dims if cov_dims is not None else None ) # Get fresh matrices with n_timesteps placeholder still intact. @@ -2741,7 +3085,14 @@ def impulse_response_function( matrices = self._insert_constant_timestep(self.unpack_statespace(), step=n_steps) P0, _, c, d, T, Z, R, H, post_Q = matrices - x0 = pm.Deterministic("x0_new", pt.zeros(self.k_states), dims=[ALL_STATE_DIM]) + if self.batch_size: + x0 = pm.Deterministic( + "x0_new", + pt.zeros((*self.batch_size, self.k_states)), + dims=[BATCH_DIM, ALL_STATE_DIM], + ) + else: + x0 = pm.Deterministic("x0_new", pt.zeros(self.k_states), dims=[ALL_STATE_DIM]) if use_posterior_cov: Q = post_Q @@ -2753,31 +3104,53 @@ def impulse_response_function( Q = pt.linalg.cholesky(Q) / pt.diag(Q) if shock_trajectory is None: - shock_trajectory = pt.zeros((n_steps, self.k_posdef)) + shock_trajectory = pt.zeros((*self.batch_size, n_steps, self.k_posdef)) if Q is not None: - init_shock = pm.MvNormal( - "initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM], method=mvn_method - ) + if self.batch_size: + init_shock = pm.MvNormal( + "initial_shock", + mu=0, + cov=Q, + dims=[BATCH_DIM, SHOCK_DIM], + method=mvn_method, + ) + else: + init_shock = pm.MvNormal( + "initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM], method=mvn_method + ) else: - init_shock = pm.Deterministic( - "initial_shock", - pt.as_tensor_variable(np.atleast_1d(shock_size)), - dims=[SHOCK_DIM], - ) - shock_trajectory = pt.set_subtensor(shock_trajectory[0], init_shock) + if self.batch_size: + init_shock = pm.Deterministic( + "initial_shock", + pt.as_tensor_variable(np.atleast_1d(shock_size)), + dims=[BATCH_DIM, SHOCK_DIM], + ) + else: + init_shock = pm.Deterministic( + "initial_shock", + pt.as_tensor_variable(np.atleast_1d(shock_size)), + dims=[SHOCK_DIM], + ) + if self.batch_size: + shock_trajectory = pt.set_subtensor(shock_trajectory[:, 0], init_shock) + else: + shock_trajectory = pt.set_subtensor(shock_trajectory[0], init_shock) else: shock_trajectory = pt.as_tensor_variable(shock_trajectory) time_varying_T = "transition" in self.ssm.time_varying_names + if self.batch_size: + shock_trajectory = shock_trajectory.swapaxes(0, 1) + def irf_step(*args): if time_varying_T: shock, T, x, c, R = args else: shock, x, c, T, R = args - next_x = c + T @ x + R @ shock + next_x = c + bmv(T, x) + bmv(R, shock) return next_x sequences = [shock_trajectory, T] if time_varying_T else [shock_trajectory] @@ -2793,7 +3166,13 @@ def irf_step(*args): return_updates=False, ) - pm.Deterministic("irf", irf, dims=[TIME_DIM, ALL_STATE_DIM]) + if self.batch_size: + irf = irf.swapaxes(0, 1) + + if self.batch_size: + pm.Deterministic("irf", irf, dims=[BATCH_DIM, TIME_DIM, ALL_STATE_DIM]) + else: + pm.Deterministic("irf", irf, dims=[TIME_DIM, ALL_STATE_DIM]) irf_idata = pm.sample_posterior_predictive( idata, diff --git a/pymc_extras/statespace/filters/distributions.py b/pymc_extras/statespace/filters/distributions.py index ed98ee977..7c549647c 100644 --- a/pymc_extras/statespace/filters/distributions.py +++ b/pymc_extras/statespace/filters/distributions.py @@ -10,6 +10,8 @@ from pytensor.graph.basic import Node from pytensor.tensor.random import multivariate_normal +from pymc_extras.statespace.utils.batch_tools import bmv + floatX = pytensor.config.floatX COV_ZERO_TOL = 0 @@ -203,8 +205,8 @@ def step_fn(*args): for src_idx, dst_idx in enumerate(non_seq_positions): ordered[dst_idx] = non_seqs[src_idx] c, d, T, Z, R, H, Q = ordered - k = T.shape[0] - a = state[:k] + k = T.shape[-1] + a = state[..., :k] middle_rng, a_innovation = pm.MvNormal.dist( mu=0, cov=Q, rng=rng, method=method, return_next_rng=True @@ -213,13 +215,13 @@ def step_fn(*args): mu=0, cov=H, rng=middle_rng, method=method, return_next_rng=True ) - a_mu = c + T @ a - a_next = a_mu + R @ a_innovation + a_mu = c + bmv(T, a) + a_next = a_mu + bmv(R, a_innovation) - y_mu = d + Z @ a_next + y_mu = d + bmv(Z, a_next) y_next = y_mu + y_innovation - next_state = pt.concatenate([a_next, y_next], axis=0) + next_state = pt.concatenate([a_next, y_next], axis=-1) return next_rng, next_state @@ -227,9 +229,9 @@ def step_fn(*args): H_init = H_ if H_ in non_sequences else H_[0] init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method=method) - init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng, method=method) + init_y_ = pm.MvNormal.dist(bmv(Z_init, init_x_), H_init, rng=rng, method=method) - init_dist_ = pt.concatenate([init_x_, init_y_], axis=0) + init_dist_ = pt.concatenate([init_x_, init_y_], axis=-1) ss_rng, statespace = pytensor.scan( step_fn, @@ -242,11 +244,12 @@ def step_fn(*args): ) if append_x0: - statespace_ = pt.concatenate([init_dist_[None], statespace], axis=0) - statespace_ = pt.specify_shape(statespace_, (steps + 1, None)) + init_dist_expanded = pt.expand_dims(init_dist_, axis=0) + statespace_ = pt.concatenate([init_dist_expanded, statespace], axis=0) + # statespace_ = pt.specify_shape(statespace_, (steps + 1, None)) else: statespace_ = statespace - statespace_ = pt.specify_shape(statespace_, (steps, None)) + # statespace_ = pt.specify_shape(statespace_, (steps, None)) linear_gaussian_ss_op = LinearGaussianStateSpaceRV( inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps, rng], @@ -287,7 +290,15 @@ def __new__( dims = kwargs.pop("dims", None) latent_dims = None obs_dims = None - if dims is not None: + if dims is not None and len(dims) > 3: + # if len(dims) != 3: + # ValueError( + # "LinearGaussianStateSpace expects 3 dims: time, all_states, and observed_states" + # ) + batch_dim, time_dim, state_dim, obs_dim = dims + latent_dims = [time_dim, batch_dim, state_dim] + obs_dims = [time_dim, batch_dim, obs_dim] + elif dims is not None: if len(dims) != 3: ValueError( "LinearGaussianStateSpace expects 3 dims: time, all_states, and observed_states" @@ -313,7 +324,7 @@ def __new__( method=method, **kwargs, ) - latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + int(append_x0), None)) + # latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + int(append_x0), None)) if k_endog is None: k_endog = cls._get_k_endog(H) latent_slice = slice(None, -k_endog) diff --git a/pymc_extras/statespace/models/ETS.py b/pymc_extras/statespace/models/ETS.py index dbe6dfcfa..9e1d26294 100644 --- a/pymc_extras/statespace/models/ETS.py +++ b/pymc_extras/statespace/models/ETS.py @@ -289,7 +289,7 @@ def __init__( k_endog, k_states, k_posdef, - filter_type, + filter_type=filter_type, verbose=verbose, measurement_error=measurement_error, mode=mode, diff --git a/pymc_extras/statespace/models/SARIMAX.py b/pymc_extras/statespace/models/SARIMAX.py index 11c0c6850..aac51505b 100644 --- a/pymc_extras/statespace/models/SARIMAX.py +++ b/pymc_extras/statespace/models/SARIMAX.py @@ -265,7 +265,7 @@ def __init__( k_endog, k_states, k_posdef, - filter_type, + filter_type=filter_type, verbose=verbose, measurement_error=measurement_error, mode=mode, diff --git a/pymc_extras/statespace/models/VARMAX.py b/pymc_extras/statespace/models/VARMAX.py index 7677b092c..cab85fa72 100644 --- a/pymc_extras/statespace/models/VARMAX.py +++ b/pymc_extras/statespace/models/VARMAX.py @@ -198,7 +198,7 @@ def __init__( k_endog, k_states, k_posdef, - filter_type, + filter_type=filter_type, verbose=verbose, measurement_error=measurement_error, mode=mode, diff --git a/pymc_extras/statespace/utils/batch_tools.py b/pymc_extras/statespace/utils/batch_tools.py new file mode 100644 index 000000000..49b8a70d9 --- /dev/null +++ b/pymc_extras/statespace/utils/batch_tools.py @@ -0,0 +1,5 @@ +import pytensor.tensor as pt + + +def bmv(A, x): + return pt.matmul(A, x[..., None])[..., 0] diff --git a/pymc_extras/statespace/utils/constants.py b/pymc_extras/statespace/utils/constants.py index 8c7b399fe..cdbfff9d7 100644 --- a/pymc_extras/statespace/utils/constants.py +++ b/pymc_extras/statespace/utils/constants.py @@ -15,6 +15,7 @@ FACTOR_DIM = "factor" ERROR_AR_PARAM_DIM = "error_lag_ar" EXOG_STATE_DIM = "exogenous" +BATCH_DIM = "batch" MISSING_FILL = -9999.0 JITTER_DEFAULT = 1e-8 if pytensor.config.floatX.endswith("64") else 1e-6 diff --git a/pymc_extras/statespace/utils/data_tools.py b/pymc_extras/statespace/utils/data_tools.py index 467ae314f..8abc1c335 100644 --- a/pymc_extras/statespace/utils/data_tools.py +++ b/pymc_extras/statespace/utils/data_tools.py @@ -10,11 +10,7 @@ from pymc.exceptions import ImputationWarning from pytensor.tensor.sharedvar import TensorSharedVariable -from pymc_extras.statespace.utils.constants import ( - MISSING_FILL, - OBS_STATE_DIM, - TIME_DIM, -) +from pymc_extras.statespace.utils.constants import BATCH_DIM, MISSING_FILL, OBS_STATE_DIM, TIME_DIM NO_TIME_INDEX_WARNING = ( "No time index found on the supplied data. A simple range index will be automatically " @@ -36,11 +32,13 @@ def get_data_dims(data): return data_dims -def _validate_data_shape(data_shape, n_obs, obs_coords=None, check_col_names=False, col_names=None): +def _validate_data_shape( + data_shape, n_obs, obs_coords=None, check_col_names=False, col_names=None, batched=False +): if col_names is None: col_names = [] - if len(data_shape) != 2: + if not batched and len(data_shape) != 2: raise ValueError("Data must be a 2d matrix") if data_shape[-1] != n_obs: @@ -59,22 +57,27 @@ def _validate_data_shape(data_shape, n_obs, obs_coords=None, check_col_names=Fal ) -def preprocess_tensor_data(data, n_obs, obs_coords=None): +def preprocess_tensor_data(data, n_obs, obs_coords=None, batched=False): data_shape = data.shape.eval() - _validate_data_shape(data_shape, n_obs, obs_coords) + _validate_data_shape(data_shape, n_obs, obs_coords, batched=batched) if obs_coords is not None: warnings.warn(NO_TIME_INDEX_WARNING) - index = np.arange(data_shape[0], dtype="int") + + index = ( + np.arange(data_shape[0], dtype="int") + if not batched + else np.arange(data_shape[1], dtype="int") + ) return data.eval(), index -def preprocess_numpy_data(data, n_obs, obs_coords=None): - _validate_data_shape(data.shape, n_obs, obs_coords) +def preprocess_numpy_data(data, n_obs, obs_coords=None, batched=False): + _validate_data_shape(data.shape, n_obs, obs_coords, batched=batched) if obs_coords is not None: warnings.warn(NO_TIME_INDEX_WARNING) - index = np.arange(data.shape[0], dtype="int") + index = np.arange(data.shape[0], dtype="int") if not batched else np.arange(data.shape[1]) return data, index @@ -122,11 +125,15 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals return preprocess_numpy_data(data.values, n_obs, obs_coords) -def add_data_to_active_model(values, index, data_dims=None): +def add_data_to_active_model(values, index, data_dims=None, batched=False): pymc_mod = modelcontext(None) if data_dims is None: - data_dims = [TIME_DIM, OBS_STATE_DIM] - time_dim = data_dims[0] + if not batched: + data_dims = [TIME_DIM, OBS_STATE_DIM] + else: + data_dims = [BATCH_DIM, TIME_DIM, OBS_STATE_DIM] + + time_dim = data_dims[0] if not batched else data_dims[1] if isinstance(index, pd.Index): index = index.rename(time_dim) @@ -145,10 +152,14 @@ def add_data_to_active_model(values, index, data_dims=None): # If the data has just one column, we need to specify the shape as (None, 1), or else the JAX backend will # raise a broadcasting error. - if values.shape[-1] == 1 or values.ndim == 1: + if (values.shape[-1] == 1 or values.ndim == 1) and not batched: data_shape = (None, 1) - else: + elif (values.shape[-1] == 1 or values.ndim == 1) and batched: + data_shape = (values.shape[0], None, 1) + elif not batched: data_shape = (None, values.shape[-1]) + else: + data_shape = (values.shape[0], None, *values.shape[2:]) data = pm.Data("data", values, dims=data_dims, shape=data_shape) @@ -184,10 +195,14 @@ def mask_missing_values_in_data(values, missing_fill_value=None): def register_data_with_pymc( data, n_obs, obs_coords, register_data=True, missing_fill_value=None, data_dims=None ): + batched = False + if data_dims and BATCH_DIM in data_dims: + batched = True + if isinstance(data, pt.TensorVariable | TensorSharedVariable): - values, index = preprocess_tensor_data(data, n_obs, obs_coords) + values, index = preprocess_tensor_data(data, n_obs, obs_coords, batched) elif isinstance(data, np.ndarray): - values, index = preprocess_numpy_data(data, n_obs, obs_coords) + values, index = preprocess_numpy_data(data, n_obs, obs_coords, batched) elif isinstance(data, pd.DataFrame | pd.Series): values, index = preprocess_pandas_data(data, n_obs, obs_coords) else: @@ -196,7 +211,7 @@ def register_data_with_pymc( data, nan_mask = mask_missing_values_in_data(values, missing_fill_value) if register_data: - data = add_data_to_active_model(data, index, data_dims) + data = add_data_to_active_model(data, index, data_dims, batched) else: data = pytensor.shared(data, name="data") return data, nan_mask diff --git a/tests/statespace/core/test_statespace.py b/tests/statespace/core/test_statespace.py index 406aaf9cb..6e1210d9e 100644 --- a/tests/statespace/core/test_statespace.py +++ b/tests/statespace/core/test_statespace.py @@ -1396,3 +1396,287 @@ def test_impulse_response_function(self, ss_mod_time_varying, idata_time_varying assert "irf" in result assert result["irf"].shape[2] == 20 assert not np.any(np.isnan(result["irf"].values)) + + +@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.") +@pytest.mark.filterwarnings("ignore:No start date provided") +@pytest.mark.parametrize("batch_dims", [(2,), (5,), (7,)]) +def test_batch_ssm(batch_dims: tuple[int], rng) -> pm.Model: + mod = st.LevelTrend(order=2, innovations_order=[0, 1]) + mod += st.Autoregressive(name="ar", order=1) + mod += st.MeasurementError(name="obs") + ssm = mod.build(name="batch_ssm", mode="NUMBA") + + if len(batch_dims) > 1: + raise NotImplementedError + + batched_data = np.random.normal(0, 1, size=(*batch_dims, 100, 1)) + + with pm.Model( + coords=ssm.coords | {"batch": [f"batch_{i}" for i in range(batch_dims[0])]} + ) as pymc_mod: + P0 = pm.Deterministic( + "P0", pt.tile(pt.eye(3) * 1, (*batch_dims, 1, 1)), dims=("batch", "state", "state_aux") + ) + + initial_level_trend = pm.Normal("initial_level_trend", dims=("batch", "state_level_trend")) + params_ar = pm.Beta("params_ar", alpha=3, beta=3, dims=("batch", "lag_ar")) + + sigma_level_trend = pm.Gamma( + "sigma_level_trend", alpha=2, beta=50, dims=("batch", "shock_level_trend") + ) + sigma_ar = pm.Gamma("sigma_ar", alpha=2, beta=5, shape=(*batch_dims,)) + sigma_obs = pm.HalfNormal("sigma_obs", sigma=0.05, shape=(*batch_dims,)) + + ssm.build_statespace_graph(data=batched_data) + + coord_dimensions = {k: len(v) for k, v in pymc_mod.coords.items()} + + with pymc_mod: + prior = pm.sample_prior_predictive(compile_kwargs={"mode": "NUMBA"}, random_seed=rng) + idata = pm.sample(tune=10, draws=10, compile_kwargs={"mode": "NUMBA"}, random_seed=rng) + + unconditional_prior = ssm.sample_unconditional_prior( + prior, mvn_method="cholesky", random_seed=rng + ) + conditional_prior = ssm.sample_conditional_prior(prior, mvn_method="cholesky", random_seed=rng) + + prior_chain = prior.prior.dims["chain"] + prior_draw = prior.prior.dims["draw"] + + assert prior.prior["params_ar"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["lag_ar"], + ) + assert prior.prior["sigma_ar"].shape == (prior_chain, prior_draw, coord_dimensions["batch"]) + assert prior.prior["initial_level_trend"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["state_level_trend"], + ) + assert prior.prior["sigma_level_trend"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["shock_level_trend"], + ) + assert prior.prior["sigma_obs"].shape == (prior_chain, prior_draw, coord_dimensions["batch"]) + assert prior.prior["P0"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["state"], + coord_dimensions["state_aux"], + ) + + assert conditional_prior["filtered_prior"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["state"], + ) + assert conditional_prior["filtered_prior_observed"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["observed_state"], + ) + assert conditional_prior["predicted_prior"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["state"], + ) + assert conditional_prior["predicted_prior_observed"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["observed_state"], + ) + assert conditional_prior["smoothed_prior"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["state"], + ) + assert conditional_prior["smoothed_prior_observed"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["observed_state"], + ) + + assert unconditional_prior["prior_latent"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["time"], + coord_dimensions["batch"], + coord_dimensions["state"], + ) + assert unconditional_prior["prior_observed"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["time"], + coord_dimensions["batch"], + coord_dimensions["observed_state"], + ) + + posterior_chain = idata.posterior.dims["chain"] + posterior_draw = idata.posterior.dims["draw"] + + assert idata.posterior["params_ar"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["lag_ar"], + ) + assert idata.posterior["sigma_ar"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + ) + assert idata.posterior["initial_level_trend"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["state_level_trend"], + ) + assert idata.posterior["sigma_level_trend"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["shock_level_trend"], + ) + assert idata.posterior["sigma_obs"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + ) + assert idata.posterior["P0"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["state"], + coord_dimensions["state_aux"], + ) + + irf = ssm.impulse_response_function(idata, random_seed=rng) + irf_steps = irf.dims["time"] + assert irf.irf.shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + irf_steps, + coord_dimensions["state"], + ) + + T_sample = ssm.sample_statespace_matrices(idata, matrix_names=["T"]) + assert T_sample.posterior_predictive.T.shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["state"], + coord_dimensions["state_aux"], + ) + + filtered_covariance_sample = ssm.sample_filter_outputs( + idata, filter_output_names=["filtered_covariances"] + ) + assert filtered_covariance_sample.posterior_predictive.filtered_covariances.shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["state"], + coord_dimensions["state_aux"], + ) + + forecast = ssm.forecast(idata, periods=10, random_seed=rng) + forecast_steps = forecast.dims["time"] + assert forecast.forecast_latent.shape == ( + posterior_chain, + posterior_draw, + forecast_steps, + coord_dimensions["batch"], + coord_dimensions["state"], + ) + assert forecast.forecast_observed.shape == ( + posterior_chain, + posterior_draw, + forecast_steps, + coord_dimensions["batch"], + coord_dimensions["observed_state"], + ) + + unconditional_post = ssm.sample_unconditional_posterior( + idata, mvn_method="cholesky", random_seed=rng + ) + assert unconditional_post.posterior_latent.shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["time"], + coord_dimensions["batch"], + coord_dimensions["state"], + ) + assert unconditional_post.posterior_observed.shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["time"], + coord_dimensions["batch"], + coord_dimensions["observed_state"], + ) + + conditional_post = ssm.sample_conditional_posterior( + idata, mvn_method="cholesky", random_seed=rng + ) + assert conditional_post["filtered_posterior"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["state"], + ) + assert conditional_post["filtered_posterior_observed"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["observed_state"], + ) + assert conditional_post["predicted_posterior"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["state"], + ) + assert conditional_post["predicted_posterior_observed"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["observed_state"], + ) + assert conditional_post["smoothed_posterior"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["state"], + ) + assert conditional_post["smoothed_posterior_observed"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["observed_state"], + )