From fe8a482b42be900234f49ba7210d6d0070ecebd4 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Wed, 20 May 2026 11:42:03 -0600 Subject: [PATCH 01/10] rough implementation of batched SSM compatability --- .../temporary_scratchpad_ssm_batch_dims.ipynb | 2206 +++++++++++++++++ pymc_extras/statespace/core/statespace.py | 292 ++- .../statespace/filters/distributions.py | 38 +- pymc_extras/statespace/utils/constants.py | 1 + pymc_extras/statespace/utils/data_tools.py | 56 +- 5 files changed, 2515 insertions(+), 78 deletions(-) create mode 100644 notebooks/temporary_scratchpad_ssm_batch_dims.ipynb diff --git a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb new file mode 100644 index 000000000..5b25d958c --- /dev/null +++ b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb @@ -0,0 +1,2206 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f9e8547b", + "metadata": {}, + "source": [ + "# Notes to self:\n", + "Currently, I am able to get the model to sample with batch dimensions and I am able to use the following internal methods:\n", + "\n", + "* `sample_unconditional_prior`\n", + "* `sample_conditional_prior`\n", + "* `sample_unconditional_posterior`\n", + "* `sample_conditional_posterior`" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d95e629f", + "metadata": {}, + "outputs": [], + "source": [ + "import arviz as az\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pymc as pm\n", + "import pytensor.tensor as pt\n", + "\n", + "from pymc_extras.statespace.core.statespace import PyMCStateSpace\n", + "from pymc_extras.statespace.filters import StandardFilter, KalmanSmoother\n", + "\n", + "from pymc_extras.statespace.core.properties import (\n", + " Parameter,\n", + " State,\n", + " Shock,\n", + " Coord,\n", + ")\n", + "from pymc_extras.statespace.utils.constants import ALL_STATE_DIM, ALL_STATE_AUX_DIM, SHOCK_DIM" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "34b32aa3", + "metadata": {}, + "outputs": [], + "source": [ + "class AutoRegressiveThree(PyMCStateSpace):\n", + " def __init__(self, mode: str):\n", + " k_states = 3 # size of the state vector x\n", + " k_posdef = 1 # number of shocks (size of the state covariance matrix Q)\n", + " k_endog = 1 # number of observed states\n", + "\n", + " super().__init__(k_endog=k_endog, k_states=k_states, k_posdef=k_posdef, mode=mode)\n", + "\n", + " def make_symbolic_graph(self):\n", + " x0 = self.make_and_register_variable(\"x0\", shape=(3,))\n", + " P0 = self.make_and_register_variable(\"P0\", shape=(3, 3))\n", + "\n", + " ar_params = self.make_and_register_variable(\"ar_params\", shape=(3,))\n", + " sigma_x = self.make_and_register_variable(\"sigma_x\", shape=(1,))\n", + "\n", + " self.ssm[\"transition\", :, :] = np.eye(3, k=-1)\n", + " self.ssm[\"selection\", 0, 0] = 1\n", + " self.ssm[\"design\", 0, 0] = 1\n", + "\n", + " self.ssm[\"initial_state\", :] = x0\n", + " self.ssm[\"initial_state_cov\", :, :] = P0\n", + " self.ssm[\"transition\", 0, :] = ar_params\n", + " self.ssm[\"state_cov\", :, :] = sigma_x\n", + "\n", + " def set_parameters(self):\n", + " # Only the \"name\" parameter is required here. \"Shape\" is only used when printing the\n", + " # model requirements table. \"Dims\" are used to link variables to coords.\n", + " x0 = Parameter(name=\"x0\", shape=(3,), dims=(ALL_STATE_DIM,))\n", + " P0 = Parameter(name=\"P0\", shape=(3, 3), dims=(ALL_STATE_DIM, ALL_STATE_AUX_DIM))\n", + "\n", + " ar_params = Parameter(\n", + " name=\"ar_params\", shape=(3,), dims=(\"ar_lags\",), constraints=\"Stationary, please :)\"\n", + " )\n", + " sigma_x = Parameter(name=\"sigma_x\", shape=(1,), dims=(SHOCK_DIM,))\n", + " return x0, P0, ar_params, sigma_x\n", + "\n", + " def set_states(self):\n", + " # To get a name on the observed, we make an observed state\n", + " ts1 = State(name=\"ts1\", observed=True)\n", + "\n", + " # Since the three hidden states are lags of the data, i'll call them L1, L2 L3\n", + " L1 = State(name=\"L1.data\", observed=False)\n", + " L2 = State(name=\"L2.data\", observed=False)\n", + " L3 = State(name=\"L3.data\", observed=False)\n", + "\n", + " return ts1, L1, L2, L3\n", + "\n", + " def set_shocks(self):\n", + " # There is one shock, called the \"innovations\" in the literature, so i'll go with that\n", + " innovation = Shock(name=\"innovations\")\n", + " return innovation\n", + "\n", + " def set_coords(self):\n", + " # This function sets up the coords dictionary used by pm.Model. The parent class has a helper\n", + " # self.default_coords() that makes the coords that are always expected by a statespace model --\n", + " # stuff like state, shock, etc.\n", + "\n", + " # You need to give one Coord object per dimension used among the Parameter objects you made a\n", + "\n", + " default_coords = self.default_coords()\n", + " ar_coord = Coord(dimension=\"ar_lags\", labels=(1, 2, 3))\n", + " return *default_coords, ar_coord" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "bfae06f5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
                          Model Requirements                           \n",
+       "                                                                       \n",
+       "  Variable    Shape    Constraints                         Dimensions  \n",
+       " ───────────────────────────────────────────────────────────────────── \n",
+       "  x0          (3,)                                         ('state',)  \n",
+       "  P0          (3, 3)                           ('state', 'state_aux')  \n",
+       "  ar_params   (3,)     Stationary, please :)             ('ar_lags',)  \n",
+       "  sigma_x     (1,)                                         ('shock',)  \n",
+       "                                                                       \n",
+       " These parameters should be assigned priors inside a PyMC model block  \n",
+       "           before calling the build_statespace_graph method.           \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[3m Model Requirements \u001b[0m\n", + " \n", + " \u001b[1m \u001b[0m\u001b[1mVariable \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mShape \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mConstraints \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1m Dimensions\u001b[0m\u001b[1m \u001b[0m \n", + " ───────────────────────────────────────────────────────────────────── \n", + " x0 \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m,\u001b[1m)\u001b[0m \n", + " P0 \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m, \u001b[32m'state_aux'\u001b[0m\u001b[1m)\u001b[0m \n", + " ar_params \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m,\u001b[1m)\u001b[0m Stationary, please :\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'ar_lags'\u001b[0m,\u001b[1m)\u001b[0m \n", + " sigma_x \u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'shock'\u001b[0m,\u001b[1m)\u001b[0m \n", + " \n", + "\u001b[2;3m These parameters should be assigned priors inside a PyMC model block \u001b[0m\n", + "\u001b[2;3m before calling the build_statespace_graph method. \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "ar3 = AutoRegressiveThree(mode=\"NUMBA\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "27119d27", + "metadata": {}, + "outputs": [], + "source": [ + "data = np.random.normal(0, 1, size=(100, 2))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "d9163bda", + "metadata": {}, + "outputs": [], + "source": [ + "batched_data = data.reshape(2, 100, 1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "217b812d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:77: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n", + " warnings.warn(NO_TIME_INDEX_WARNING)\n" + ] + } + ], + "source": [ + "# Not vectorized\n", + "with pm.Model(coords=ar3.coords) as pymc_mod:\n", + " x0 = pm.Deterministic(\"x0\", pt.zeros((3,)), dims=(\"state\"))\n", + " P0 = pm.Deterministic(\"P0\", pt.eye(3) * 10, dims=(\"state\", \"state_aux\"))\n", + " ar_params = pm.Normal(\"ar_params\", shape=(3,), dims=(\"state\"))\n", + "\n", + " sigma_x = pm.Exponential(\"sigma_x\", 1, shape=(1,), dims=(\"shock\"))\n", + "\n", + " ar3.build_statespace_graph(data=data[:, [1]])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84ac6e46", + "metadata": {}, + "outputs": [], + "source": [ + "# Vectorized\n", + "with pm.Model(coords=ar3.coords | {\"batch\": [\"batch_1\", \"batch_2\"]}) as pymc_mod:\n", + " x0 = pm.Deterministic(\"x0\", pt.zeros((2, 3)), dims=(\"batch\", \"state\"))\n", + " P0 = pm.Deterministic(\n", + " \"P0\", pt.tile(pt.eye(3) * 1, (2, 1, 1)), dims=(\"batch\", \"state\", \"state_aux\")\n", + " )\n", + " ar_params = pm.Normal(\"ar_params\", dims=(\"batch\", \"state\"))\n", + "\n", + " sigma_x = pm.Exponential(\"sigma_x\", 1, dims=(\"batch\", \"shock\"))\n", + "\n", + " ar3.build_statespace_graph(data=batched_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ed208bcd", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Initializing NUTS using jitter+adapt_diag...\n", + "Multiprocess sampling (4 chains in 4 jobs)\n", + "NUTS: [ar_params, sigma_x]\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "54b5be47bbc44f7399db826f05108517", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Sampling 4 chains for 200 tune and 200 draw iterations (800 + 800 draws total) took 2 seconds.\n"
+     ]
+    }
+   ],
+   "source": [
+    "with pymc_mod:\n",
+    "    idata = pm.sample(tune=200, draws=200, compile_kwargs={\"mode\": \"NUMBA\"})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "cf2af58a",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:77: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n",
+      "  warnings.warn(NO_TIME_INDEX_WARNING)\n",
+      "Sampling: [filtered_posterior, filtered_posterior_observed, predicted_posterior, predicted_posterior_observed, smoothed_posterior, smoothed_posterior_observed]\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "fe57e8f78975419887f133c358e24166",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "post = ar3.sample_conditional_posterior(idata, mvn_method=\"cholesky\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "5c3800a9",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Sampling: [ar_params, obs, sigma_x]\n"
+     ]
+    }
+   ],
+   "source": [
+    "with pymc_mod:\n",
+    "    prior = pm.sample_prior_predictive(compile_kwargs={\"mode\": \"NUMBA\"})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "5857a994",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:77: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n",
+      "  warnings.warn(NO_TIME_INDEX_WARNING)\n",
+      "Sampling: [filtered_prior, filtered_prior_observed, predicted_prior, predicted_prior_observed, smoothed_prior, smoothed_prior_observed]\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c47107700d01475fbc7f96b068ea0aec",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataTree 'posterior_predictive'>\n",
+       "Group: /posterior_predictive\n",
+       "    Dimensions:                   (chain: 1, draw: 500, time: 100, state: 3,\n",
+       "                                   observed_state: 1)\n",
+       "    Coordinates:\n",
+       "      * chain                     (chain) int64 8B 0\n",
+       "      * draw                      (draw) int64 4kB 0 1 2 3 4 ... 495 496 497 498 499\n",
+       "      * time                      (time) int64 800B 0 1 2 3 4 5 ... 95 96 97 98 99\n",
+       "      * state                     (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
+       "      * observed_state            (observed_state) <U3 12B 'ts1'\n",
+       "    Data variables:\n",
+       "        filtered_prior            (chain, draw, time, state) float64 1MB -0.1217 ...\n",
+       "        filtered_prior_observed   (chain, draw, time, observed_state) float64 400kB ...\n",
+       "        predicted_prior           (chain, draw, time, state) float64 1MB -1.732 ....\n",
+       "        predicted_prior_observed  (chain, draw, time, observed_state) float64 400kB ...\n",
+       "        smoothed_prior            (chain, draw, time, state) float64 1MB -0.1217 ...\n",
+       "        smoothed_prior_observed   (chain, draw, time, observed_state) float64 400kB ...\n",
+       "    Attributes:\n",
+       "        created_at:                 2026-05-20T16:02:32.953527+00:00\n",
+       "        creation_library:           ArviZ\n",
+       "        creation_library_version:   1.1.0\n",
+       "        creation_library_language:  Python\n",
+       "        inference_library:          pymc\n",
+       "        inference_library_version:  6.0.0\n",
+       "        sample_dims:                ['chain', 'draw']
" + ], + "text/plain": [ + "\n", + "Group: /posterior_predictive\n", + " Dimensions: (chain: 1, draw: 500, time: 100, state: 3,\n", + " observed_state: 1)\n", + " Coordinates:\n", + " * chain (chain) int64 8B 0\n", + " * draw (draw) int64 4kB 0 1 2 3 4 ... 495 496 497 498 499\n", + " * time (time) int64 800B 0 1 2 3 4 5 ... 95 96 97 98 99\n", + " * state (state) \n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataTree 'posterior_predictive'>\n",
+       "Group: /posterior_predictive\n",
+       "    Dimensions:         (chain: 1, draw: 500, time: 100, state: 3, observed_state: 1)\n",
+       "    Coordinates:\n",
+       "      * chain           (chain) int64 8B 0\n",
+       "      * draw            (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n",
+       "      * time            (time) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99\n",
+       "      * state           (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
+       "      * observed_state  (observed_state) <U3 12B 'ts1'\n",
+       "    Data variables:\n",
+       "        prior_latent    (chain, draw, time, state) float64 1MB -5.512 ... 6.971e+23\n",
+       "        prior_observed  (chain, draw, time, observed_state) float64 400kB -5.512 ...\n",
+       "    Attributes:\n",
+       "        created_at:                 2026-05-20T16:02:37.272795+00:00\n",
+       "        creation_library:           ArviZ\n",
+       "        creation_library_version:   1.1.0\n",
+       "        creation_library_language:  Python\n",
+       "        inference_library:          pymc\n",
+       "        inference_library_version:  6.0.0\n",
+       "        sample_dims:                ['chain', 'draw']
" + ], + "text/plain": [ + "\n", + "Group: /posterior_predictive\n", + " Dimensions: (chain: 1, draw: 500, time: 100, state: 3, observed_state: 1)\n", + " Coordinates:\n", + " * chain (chain) int64 8B 0\n", + " * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n", + " * time (time) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99\n", + " * state (state) \n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "unpost = ar3.sample_unconditional_posterior(idata, mvn_method=\"cholesky\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b112c551", + "metadata": {}, + "outputs": [], + "source": [ + "unpost" + ] + }, + { + "cell_type": "markdown", + "id": "174c0e9e", + "metadata": {}, + "source": [ + "# MVN Method" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c456c29a", + "metadata": {}, + "outputs": [], + "source": [ + "class AutoRegressive3TwoSeries(PyMCStateSpace):\n", + " def __init__(self, mode: str):\n", + " k_states = 6 # 2 series × 3 lags\n", + " k_posdef = 2 # one innovation per series\n", + " k_endog = 2 # two observed series\n", + "\n", + " super().__init__(k_endog=k_endog, k_states=k_states, k_posdef=k_posdef, mode=mode)\n", + "\n", + " def make_symbolic_graph(self):\n", + " x0 = self.make_and_register_variable(\"x0\", shape=(6,))\n", + " P0 = self.make_and_register_variable(\"P0\", shape=(6, 6))\n", + "\n", + " ar_params = self.make_and_register_variable(\"ar_params\", shape=(2, 3))\n", + " sigma_x = self.make_and_register_variable(\"sigma_x\", shape=(2,))\n", + "\n", + " T = np.eye(6, k=-1)\n", + "\n", + " self.ssm[\"transition\", :, :] = T\n", + " self.ssm[\"transition\", 0, 0:3] = ar_params[0]\n", + " self.ssm[\"transition\", 3, 3:6] = ar_params[1]\n", + "\n", + " self.ssm[\"selection\", 0, 0] = 1\n", + " self.ssm[\"selection\", 3, 1] = 1\n", + "\n", + " self.ssm[\"state_cov\", :, :] = pt.diag(sigma_x)\n", + "\n", + " Z = np.zeros((2, 6))\n", + " Z[0, 0] = 1\n", + " Z[1, 3] = 1\n", + " self.ssm[\"design\", :, :] = Z\n", + "\n", + " self.ssm[\"initial_state\", :] = x0\n", + " self.ssm[\"initial_state_cov\", :, :] = P0\n", + "\n", + " def set_parameters(self):\n", + " x0 = Parameter(name=\"x0\", shape=(6,), dims=(ALL_STATE_DIM,))\n", + " P0 = Parameter(name=\"P0\", shape=(6, 6), dims=(ALL_STATE_DIM, ALL_STATE_AUX_DIM))\n", + "\n", + " ar_params = Parameter(\n", + " name=\"ar_params\",\n", + " shape=(2, 3),\n", + " dims=(\"observed_state\", \"ar_lags\"),\n", + " )\n", + "\n", + " sigma_x = Parameter(\n", + " name=\"sigma_x\",\n", + " shape=(2,),\n", + " dims=(\"observed_state\",),\n", + " )\n", + "\n", + " return x0, P0, ar_params, sigma_x\n", + "\n", + " def set_states(self):\n", + " # Observed states\n", + " ts1 = State(name=\"ts1\", observed=True)\n", + " ts2 = State(name=\"ts2\", observed=True)\n", + "\n", + " # Series 1 states\n", + " L1_s1 = State(name=\"L1.ts1\", observed=False)\n", + " L2_s1 = State(name=\"L2.ts1\", observed=False)\n", + " L3_s1 = State(name=\"L3.ts1\", observed=False)\n", + "\n", + " # Series 2 states\n", + " L1_s2 = State(name=\"L1.ts2\", observed=False)\n", + " L2_s2 = State(name=\"L2.ts2\", observed=False)\n", + " L3_s2 = State(name=\"L3.ts2\", observed=False)\n", + "\n", + " return (\n", + " ts1,\n", + " ts2,\n", + " L1_s1,\n", + " L2_s1,\n", + " L3_s1,\n", + " L1_s2,\n", + " L2_s2,\n", + " L3_s2,\n", + " )\n", + "\n", + " def set_shocks(self):\n", + " eps1 = Shock(name=\"innovation.ts1\")\n", + " eps2 = Shock(name=\"innovation.ts2\")\n", + " return eps1, eps2\n", + "\n", + " def set_coords(self):\n", + " default_coords = self.default_coords()\n", + " ar_coord = Coord(dimension=\"ar_lags\", labels=(1, 2, 3))\n", + " return *default_coords, ar_coord" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "26f6673b", + "metadata": {}, + "outputs": [], + "source": [ + "ar3.coords" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b8759374", + "metadata": {}, + "outputs": [], + "source": [ + "ar3.ssm[\"design\"].eval()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c34976d9", + "metadata": {}, + "outputs": [], + "source": [ + "ar3 = AutoRegressive3TwoSeries(mode=\"NUMBA\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f66a14c9", + "metadata": {}, + "outputs": [], + "source": [ + "with pm.Model(coords=ar3.coords) as pymc_mod:\n", + " x0 = pm.Deterministic(\"x0\", pt.zeros(6), dims=[\"state\"])\n", + " P0 = pm.Deterministic(\"P0\", pt.eye(6) * 10, dims=[\"state\", \"state_aux\"])\n", + "\n", + " # global mean per lag\n", + " rho_global = pm.Normal(\"rho_global\", 0.0, 0.5, dims=[\"ar_lags\"])\n", + " tau = pm.Exponential(\"tau\", 2.0, dims=[\"ar_lags\"])\n", + "\n", + " ar_offset = pm.Normal(\"ar_offset\", 0.0, 1.0, dims=[\"observed_state\", \"ar_lags\"])\n", + "\n", + " ar_params = pm.Deterministic(\n", + " \"ar_params\",\n", + " rho_global + tau * ar_offset,\n", + " dims=[\"observed_state\", \"ar_lags\"],\n", + " )\n", + "\n", + " sigma_x = pm.Exponential(\"sigma_x\", 1.0, dims=[\"observed_state\"])\n", + "\n", + " ar3.build_statespace_graph(data=data)\n", + " idata = pm.sample(compile_kwargs={\"mode\": \"NUMBA\"})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c8ea114", + "metadata": {}, + "outputs": [], + "source": [ + "pymc_mod.to_graphviz()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c16eee6", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymc-extras", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.14.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 8e838b648..e8dbb22c4 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -14,7 +14,7 @@ from pymc.model.transform.optimization import freeze_dims_and_data from pymc.util import RandomState from pytensor.graph.basic import Variable -from pytensor.graph.replace import graph_replace +from pytensor.graph.replace import graph_replace, vectorize_graph from pytensor.graph.traversal import explicit_graph_inputs from rich.box import SIMPLE_HEAD from rich.console import Console @@ -52,6 +52,7 @@ from pymc_extras.statespace.utils.constants import ( ALL_STATE_AUX_DIM, ALL_STATE_DIM, + BATCH_DIM, FILTER_OUTPUT_DIMS, FILTER_OUTPUT_TYPES, JITTER_DEFAULT, @@ -873,8 +874,18 @@ def _insert_random_variables(self): matrices = list(self._unpack_statespace_with_placeholders()) - replacement_dict = {var: pymc_model[name] for name, var in self._name_to_variable.items()} - self.subbed_ssm = graph_replace(matrices, replace=replacement_dict, strict=True) + if pymc_model["x0"].type.ndim > 1: + self.batch_dims = pymc_model["x0"].type.shape[0] + replacement_dict = { + var: pymc_model[name] for name, var in self._name_to_variable.items() + } + self.subbed_ssm = vectorize_graph(matrices, replace=replacement_dict) + else: + self.batch_dims = None + replacement_dict = { + var: pymc_model[name] for name, var in self._name_to_variable.items() + } + self.subbed_ssm = graph_replace(matrices, replace=replacement_dict) def _insert_data_variables(self): """ @@ -1011,6 +1022,115 @@ def _register_kalman_filter_outputs_with_pymc_model(outputs: tuple[pt.TensorVari dims = tuple([dim if dim in coords.keys() else None for dim in dim_names]) pm.Deterministic(name, var, dims=dims) + def make_vectorized_filter( + self, + missing_fill_value, + cov_jitter, + time_varying_names, + ): + def maybe_tv(name, base_sig): + """ + Add leading time dimension if matrix is time varying. + """ + if name in time_varying_names: + return f"(t,{base_sig})" + return f"({base_sig})" + + def vectorize_filter(data, x0, P0, c, d, T, Z, R, H, Q): + return self.kalman_filter.build_graph( + data, + x0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + missing_fill_value, + cov_jitter, + time_varying_names, + ) + + input_signature = ",".join( + [ + "(t,o)", # data + "(k)", # Initial state + "(k,k)", # Initial cov + maybe_tv("c", "k"), # state eq. intercept + maybe_tv("d", "o"), # observation eq. intercept + maybe_tv("T", "k,k"), # transition + maybe_tv("Z", "o,k"), # design + maybe_tv("R", "k,r"), # selection + maybe_tv("H", "o,o"), # observation cov + maybe_tv("Q", "r,r"), # process cov + ] + ) + + output_signature = ",".join( + [ + "(t,k)", # filtered states + "(t,k)", # predicted states + "(t,o)", # forecasts + "(t,k,k)", # filtered covs + "(t,k,k)", # predicted covs + "(t,o,o)", # forecast covs + "(t)", # loglikelihoods + ] + ) + + signature = f"{input_signature}->{output_signature}" + + return pt.vectorize( + vectorize_filter, + signature=signature, + ) + + def make_vectorized_smoother(self, cov_jitter, time_varying_names): + def maybe_tv(name, base_sig): + """ + Add leading time dimension if matrix is time varying. + """ + if name in time_varying_names: + return f"(t,{base_sig})" + return f"({base_sig})" + + def vectorize_smoother(T, R, Q, filtered_states, filtered_covariances): + return self.kalman_smoother.build_graph( + T, + R, + Q, + filtered_states, + filtered_covariances, + cov_jitter, + time_varying_names, + ) + + input_signature = ",".join( + [ + maybe_tv("T", "k,k"), # transition + maybe_tv("R", "k,r"), # selection + maybe_tv("Q", "r,r"), # process cov + "(t, k)", # filtered_states + "(t, k, k)", # filtered_covariances + ] + ) + + output_signature = ",".join( + [ + "(t,k)", # smoothed states + "(t,k,k)", # smoothed covs + ] + ) + + signature = f"{input_signature}->{output_signature}" + + return pt.vectorize( + vectorize_smoother, + signature=signature, + ) + def build_statespace_graph( self, data: np.ndarray | pd.DataFrame | pt.TensorVariable, @@ -1092,7 +1212,6 @@ def build_statespace_graph( self.mode = mode pm_mod = modelcontext(None) - self._insert_random_variables() self._save_exogenous_data_info() self._insert_data_variables() @@ -1100,24 +1219,41 @@ def build_statespace_graph( obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None) self._fit_data = data + data_dims = None + + if data.ndim > 2: + data_dims = (BATCH_DIM, TIME_DIM, OBS_STATE_DIM) + data, nan_mask = register_data_with_pymc( data, n_obs=self.ssm.k_endog, obs_coords=obs_coords, register_data=register_data, missing_fill_value=missing_fill_value, + data_dims=data_dims, ) # Order is important here: only call _insert_data_shape_into_n_timesteps after data has been registered. self._insert_data_shape_into_n_timesteps(data) - filter_outputs = self.kalman_filter.build_graph( - pt.as_tensor_variable(data), - *self.unpack_statespace(), - missing_fill_value=missing_fill_value, - cov_jitter=cov_jitter, - time_varying_names=self.ssm.time_varying_names, - ) + if data_dims and BATCH_DIM in data_dims: + vectorized_filter = self.make_vectorized_filter( + missing_fill_value=missing_fill_value, + cov_jitter=cov_jitter, + time_varying_names=self.ssm.time_varying_names, + ) + + filter_outputs = vectorized_filter( + pt.as_tensor_variable(data), *self.unpack_statespace() + ) + else: + filter_outputs = self.kalman_filter.build_graph( + pt.as_tensor_variable(data), + *self.unpack_statespace(), + missing_fill_value=missing_fill_value, + cov_jitter=cov_jitter, + time_varying_names=self.ssm.time_varying_names, + ) logp = filter_outputs.pop(-1) states, covs = filter_outputs[:3], filter_outputs[3:] @@ -1133,6 +1269,9 @@ def build_statespace_graph( obs_dims = FILTER_OUTPUT_DIMS["predicted_observed_states"] obs_dims = obs_dims if all([dim in pm_mod.coords.keys() for dim in obs_dims]) else None + if data_dims and BATCH_DIM in data_dims: + obs_dims = (BATCH_DIM, *obs_dims) + SequenceMvNormal( "obs", mus=observed_states, @@ -1216,6 +1355,8 @@ def _build_dummy_graph(self) -> None: def infer_variable_shape(name): shape = self._name_to_variable[name].type.shape + if self.batch_dims: + shape = (self.batch_dims, *shape) if not any(dim is None for dim in shape): return shape @@ -1299,6 +1440,9 @@ def _kalman_filter_outputs_from_dummy_graph( obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None) + if data.ndim > 2: + data_dims = (BATCH_DIM, TIME_DIM, OBS_STATE_DIM) + data, nan_mask = register_data_with_pymc( data, n_obs=self.ssm.k_endog, @@ -1307,19 +1451,37 @@ def _kalman_filter_outputs_from_dummy_graph( register_data=True, ) - filter_outputs = self.kalman_filter.build_graph( - data, - x0, - P0, - c, - d, - T, - Z, - R, - H, - Q, - time_varying_names=self.ssm.time_varying_names, - ) + if data_dims and BATCH_DIM in data_dims: + vectorized_filter = self.make_vectorized_filter( + None, None, time_varying_names=self.ssm.time_varying_names + ) + + filter_outputs = vectorized_filter( + data, + x0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + ) + else: + filter_outputs = self.kalman_filter.build_graph( + data, + x0, + P0, + c, + d, + T, + Z, + R, + H, + Q, + time_varying_names=self.ssm.time_varying_names, + ) filter_outputs.pop(-1) states, covariances = filter_outputs[:3], filter_outputs[3:] @@ -1327,14 +1489,24 @@ def _kalman_filter_outputs_from_dummy_graph( filtered_states, predicted_states, _ = states filtered_covariances, predicted_covariances, _ = covariances - [smoothed_states, smoothed_covariances] = self.kalman_smoother.build_graph( - T, - R, - Q, - filtered_states, - filtered_covariances, - time_varying_names=self.ssm.time_varying_names, - ) + if data_dims and BATCH_DIM in data_dims: + vectorized_smoother = self.make_vectorized_smoother(1e-8, self.ssm.time_varying_names) + [smoothed_states, smoothed_covariances] = vectorized_smoother( + T, + R, + Q, + filtered_states, + filtered_covariances, + ) + else: + [smoothed_states, smoothed_covariances] = self.kalman_smoother.build_graph( + T, + R, + Q, + filtered_states, + filtered_covariances, + time_varying_names=self.ssm.time_varying_names, + ) grouped_outputs = [ (filtered_states, filtered_covariances), @@ -1417,16 +1589,38 @@ def _sample_conditional( for name, (mu, cov) in zip(FILTER_OUTPUT_TYPES, grouped_outputs): dummy_ll = pt.zeros_like(mu) - state_dims = ( - (TIME_DIM, ALL_STATE_DIM) - if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM]]) - else (None, None) - ) - obs_dims = ( - (TIME_DIM, OBS_STATE_DIM) - if all([dim in self._fit_coords for dim in [TIME_DIM, OBS_STATE_DIM]]) - else (None, None) - ) + if self.batch_dims: + state_dims = ( + (BATCH_DIM, TIME_DIM, ALL_STATE_DIM) + if all( + [ + dim in self._fit_coords + for dim in [BATCH_DIM, TIME_DIM, ALL_STATE_DIM] + ] + ) + else (None, None, None) + ) + obs_dims = ( + (BATCH_DIM, TIME_DIM, OBS_STATE_DIM) + if all( + [ + dim in self._fit_coords + for dim in [BATCH_DIM, TIME_DIM, OBS_STATE_DIM] + ] + ) + else (None, None, None) + ) + else: + state_dims = ( + (TIME_DIM, ALL_STATE_DIM) + if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM]]) + else (None, None) + ) + obs_dims = ( + (TIME_DIM, OBS_STATE_DIM) + if all([dim in self._fit_coords for dim in [TIME_DIM, OBS_STATE_DIM]]) + else (None, None) + ) SequenceMvNormal( f"{name}_{group}", @@ -1551,8 +1745,18 @@ def _sample_unconditional( else: steps = len(temp_coords[TIME_DIM]) - 1 - if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]]): - dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] + if self.batch_dims: + if all( + [ + dim in self._fit_coords + for dim in [BATCH_DIM, TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] + ] + ): + dims = [BATCH_DIM, TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] + + else: + if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]]): + dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] with pm.Model(coords=temp_coords if dims is not None else None) as forward_model: self._build_dummy_graph() diff --git a/pymc_extras/statespace/filters/distributions.py b/pymc_extras/statespace/filters/distributions.py index ed98ee977..d291033c0 100644 --- a/pymc_extras/statespace/filters/distributions.py +++ b/pymc_extras/statespace/filters/distributions.py @@ -157,6 +157,9 @@ def rv_op( append_x0=True, method="svd", ): + def bmv(A, x): + return pt.matmul(A, x[..., None])[..., 0] + if sequence_names is None: sequence_names = [] @@ -203,8 +206,8 @@ def step_fn(*args): for src_idx, dst_idx in enumerate(non_seq_positions): ordered[dst_idx] = non_seqs[src_idx] c, d, T, Z, R, H, Q = ordered - k = T.shape[0] - a = state[:k] + k = T.shape[-1] + a = state[..., :k] middle_rng, a_innovation = pm.MvNormal.dist( mu=0, cov=Q, rng=rng, method=method, return_next_rng=True @@ -213,13 +216,13 @@ def step_fn(*args): mu=0, cov=H, rng=middle_rng, method=method, return_next_rng=True ) - a_mu = c + T @ a - a_next = a_mu + R @ a_innovation + a_mu = c + bmv(T, a) + a_next = a_mu + bmv(R, a_innovation) - y_mu = d + Z @ a_next + y_mu = d + bmv(Z, a_next) y_next = y_mu + y_innovation - next_state = pt.concatenate([a_next, y_next], axis=0) + next_state = pt.concatenate([a_next, y_next], axis=-1) return next_rng, next_state @@ -227,9 +230,9 @@ def step_fn(*args): H_init = H_ if H_ in non_sequences else H_[0] init_x_ = pm.MvNormal.dist(a0_, P0_, rng=rng, method=method) - init_y_ = pm.MvNormal.dist(Z_init @ init_x_, H_init, rng=rng, method=method) + init_y_ = pm.MvNormal.dist(bmv(Z_init, init_x_), H_init, rng=rng, method=method) - init_dist_ = pt.concatenate([init_x_, init_y_], axis=0) + init_dist_ = pt.concatenate([init_x_, init_y_], axis=-1) ss_rng, statespace = pytensor.scan( step_fn, @@ -242,11 +245,12 @@ def step_fn(*args): ) if append_x0: - statespace_ = pt.concatenate([init_dist_[None], statespace], axis=0) - statespace_ = pt.specify_shape(statespace_, (steps + 1, None)) + init_dist_expanded = pt.expand_dims(init_dist_, axis=0) + statespace_ = pt.concatenate([init_dist_expanded, statespace], axis=0) + # statespace_ = pt.specify_shape(statespace_, (steps + 1, None)) else: statespace_ = statespace - statespace_ = pt.specify_shape(statespace_, (steps, None)) + # statespace_ = pt.specify_shape(statespace_, (steps, None)) linear_gaussian_ss_op = LinearGaussianStateSpaceRV( inputs=[a0_, P0_, c_, d_, T_, Z_, R_, H_, Q_, steps, rng], @@ -287,7 +291,15 @@ def __new__( dims = kwargs.pop("dims", None) latent_dims = None obs_dims = None - if dims is not None: + if dims is not None and len(dims) > 3: + # if len(dims) != 3: + # ValueError( + # "LinearGaussianStateSpace expects 3 dims: time, all_states, and observed_states" + # ) + batch_dim, time_dim, state_dim, obs_dim = dims + latent_dims = [time_dim, batch_dim, state_dim] + obs_dims = [time_dim, batch_dim, obs_dim] + elif dims is not None: if len(dims) != 3: ValueError( "LinearGaussianStateSpace expects 3 dims: time, all_states, and observed_states" @@ -313,7 +325,7 @@ def __new__( method=method, **kwargs, ) - latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + int(append_x0), None)) + # latent_obs_combined = pt.specify_shape(latent_obs_combined, (steps + int(append_x0), None)) if k_endog is None: k_endog = cls._get_k_endog(H) latent_slice = slice(None, -k_endog) diff --git a/pymc_extras/statespace/utils/constants.py b/pymc_extras/statespace/utils/constants.py index 8c7b399fe..cdbfff9d7 100644 --- a/pymc_extras/statespace/utils/constants.py +++ b/pymc_extras/statespace/utils/constants.py @@ -15,6 +15,7 @@ FACTOR_DIM = "factor" ERROR_AR_PARAM_DIM = "error_lag_ar" EXOG_STATE_DIM = "exogenous" +BATCH_DIM = "batch" MISSING_FILL = -9999.0 JITTER_DEFAULT = 1e-8 if pytensor.config.floatX.endswith("64") else 1e-6 diff --git a/pymc_extras/statespace/utils/data_tools.py b/pymc_extras/statespace/utils/data_tools.py index 467ae314f..c1255e2ef 100644 --- a/pymc_extras/statespace/utils/data_tools.py +++ b/pymc_extras/statespace/utils/data_tools.py @@ -10,11 +10,7 @@ from pymc.exceptions import ImputationWarning from pytensor.tensor.sharedvar import TensorSharedVariable -from pymc_extras.statespace.utils.constants import ( - MISSING_FILL, - OBS_STATE_DIM, - TIME_DIM, -) +from pymc_extras.statespace.utils.constants import BATCH_DIM, MISSING_FILL, OBS_STATE_DIM, TIME_DIM NO_TIME_INDEX_WARNING = ( "No time index found on the supplied data. A simple range index will be automatically " @@ -36,11 +32,13 @@ def get_data_dims(data): return data_dims -def _validate_data_shape(data_shape, n_obs, obs_coords=None, check_col_names=False, col_names=None): +def _validate_data_shape( + data_shape, n_obs, obs_coords=None, check_col_names=False, col_names=None, batched=False +): if col_names is None: col_names = [] - if len(data_shape) != 2: + if not batched and len(data_shape) != 2: raise ValueError("Data must be a 2d matrix") if data_shape[-1] != n_obs: @@ -59,22 +57,27 @@ def _validate_data_shape(data_shape, n_obs, obs_coords=None, check_col_names=Fal ) -def preprocess_tensor_data(data, n_obs, obs_coords=None): +def preprocess_tensor_data(data, n_obs, obs_coords=None, batched=False): data_shape = data.shape.eval() - _validate_data_shape(data_shape, n_obs, obs_coords) + _validate_data_shape(data_shape, n_obs, obs_coords, batched=batched) if obs_coords is not None: warnings.warn(NO_TIME_INDEX_WARNING) - index = np.arange(data_shape[0], dtype="int") + + index = ( + np.arange(data_shape[0], dtype="int") + if not batched + else np.arange(data_shape[1], dtype="int") + ) return data.eval(), index -def preprocess_numpy_data(data, n_obs, obs_coords=None): - _validate_data_shape(data.shape, n_obs, obs_coords) +def preprocess_numpy_data(data, n_obs, obs_coords=None, batched=False): + _validate_data_shape(data.shape, n_obs, obs_coords, batched=batched) if obs_coords is not None: warnings.warn(NO_TIME_INDEX_WARNING) - index = np.arange(data.shape[0], dtype="int") + index = np.arange(data.shape[0], dtype="int") if not batched else np.arange(data.shape[1]) return data, index @@ -122,11 +125,14 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals return preprocess_numpy_data(data.values, n_obs, obs_coords) -def add_data_to_active_model(values, index, data_dims=None): +def add_data_to_active_model(values, index, data_dims=None, batched=False): pymc_mod = modelcontext(None) - if data_dims is None: + if data_dims is None and not batched: data_dims = [TIME_DIM, OBS_STATE_DIM] - time_dim = data_dims[0] + else: + data_dims = [BATCH_DIM, TIME_DIM, OBS_STATE_DIM] + + time_dim = data_dims[0] if not batched else data_dims[1] if isinstance(index, pd.Index): index = index.rename(time_dim) @@ -145,10 +151,14 @@ def add_data_to_active_model(values, index, data_dims=None): # If the data has just one column, we need to specify the shape as (None, 1), or else the JAX backend will # raise a broadcasting error. - if values.shape[-1] == 1 or values.ndim == 1: + if (values.shape[-1] == 1 or values.ndim == 1) and not batched: data_shape = (None, 1) - else: + elif (values.shape[-1] == 1 or values.ndim == 1) and batched: + data_shape = (values.shape[0], None, 1) + elif not batched: data_shape = (None, values.shape[-1]) + else: + data_shape = (values.shape[0], None, *values.shape[2:]) data = pm.Data("data", values, dims=data_dims, shape=data_shape) @@ -184,10 +194,14 @@ def mask_missing_values_in_data(values, missing_fill_value=None): def register_data_with_pymc( data, n_obs, obs_coords, register_data=True, missing_fill_value=None, data_dims=None ): + batched = False + if data_dims and BATCH_DIM in data_dims: + batched = True + if isinstance(data, pt.TensorVariable | TensorSharedVariable): - values, index = preprocess_tensor_data(data, n_obs, obs_coords) + values, index = preprocess_tensor_data(data, n_obs, obs_coords, batched) elif isinstance(data, np.ndarray): - values, index = preprocess_numpy_data(data, n_obs, obs_coords) + values, index = preprocess_numpy_data(data, n_obs, obs_coords, batched) elif isinstance(data, pd.DataFrame | pd.Series): values, index = preprocess_pandas_data(data, n_obs, obs_coords) else: @@ -196,7 +210,7 @@ def register_data_with_pymc( data, nan_mask = mask_missing_values_in_data(values, missing_fill_value) if register_data: - data = add_data_to_active_model(data, index, data_dims) + data = add_data_to_active_model(data, index, data_dims, batched) else: data = pytensor.shared(data, name="data") return data, nan_mask From ded6f3f34f04f0c2f2f599e2c491cf88b79fc4ac Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Thu, 21 May 2026 08:11:54 -0600 Subject: [PATCH 02/10] added a parameter to explicitly define the batch size, updated VARMAX, SARIMAX, ETS so that positional arguments don't clash with the new parameter, updated error in control flow for data registry --- .../temporary_scratchpad_ssm_batch_dims.ipynb | 1945 ++--------------- pymc_extras/statespace/core/statespace.py | 18 +- pymc_extras/statespace/models/ETS.py | 2 +- pymc_extras/statespace/models/SARIMAX.py | 2 +- pymc_extras/statespace/models/VARMAX.py | 2 +- pymc_extras/statespace/utils/data_tools.py | 9 +- 6 files changed, 191 insertions(+), 1787 deletions(-) diff --git a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb index 5b25d958c..f3a0d405a 100644 --- a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb +++ b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb @@ -16,7 +16,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "d95e629f", "metadata": {}, "outputs": [], @@ -41,7 +41,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "34b32aa3", "metadata": {}, "outputs": [], @@ -51,8 +51,11 @@ " k_states = 3 # size of the state vector x\n", " k_posdef = 1 # number of shocks (size of the state covariance matrix Q)\n", " k_endog = 1 # number of observed states\n", + " batch_size = 2\n", "\n", - " super().__init__(k_endog=k_endog, k_states=k_states, k_posdef=k_posdef, mode=mode)\n", + " super().__init__(\n", + " k_endog=k_endog, k_states=k_states, k_posdef=k_posdef, batch_size=2, mode=mode\n", + " )\n", "\n", " def make_symbolic_graph(self):\n", " x0 = self.make_and_register_variable(\"x0\", shape=(3,))\n", @@ -112,51 +115,17 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "bfae06f5", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
                          Model Requirements                           \n",
-       "                                                                       \n",
-       "  Variable    Shape    Constraints                         Dimensions  \n",
-       " ───────────────────────────────────────────────────────────────────── \n",
-       "  x0          (3,)                                         ('state',)  \n",
-       "  P0          (3, 3)                           ('state', 'state_aux')  \n",
-       "  ar_params   (3,)     Stationary, please :)             ('ar_lags',)  \n",
-       "  sigma_x     (1,)                                         ('shock',)  \n",
-       "                                                                       \n",
-       " These parameters should be assigned priors inside a PyMC model block  \n",
-       "           before calling the build_statespace_graph method.           \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[3m Model Requirements \u001b[0m\n", - " \n", - " \u001b[1m \u001b[0m\u001b[1mVariable \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mShape \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mConstraints \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1m Dimensions\u001b[0m\u001b[1m \u001b[0m \n", - " ───────────────────────────────────────────────────────────────────── \n", - " x0 \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m,\u001b[1m)\u001b[0m \n", - " P0 \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m, \u001b[32m'state_aux'\u001b[0m\u001b[1m)\u001b[0m \n", - " ar_params \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m,\u001b[1m)\u001b[0m Stationary, please :\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'ar_lags'\u001b[0m,\u001b[1m)\u001b[0m \n", - " sigma_x \u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'shock'\u001b[0m,\u001b[1m)\u001b[0m \n", - " \n", - "\u001b[2;3m These parameters should be assigned priors inside a PyMC model block \u001b[0m\n", - "\u001b[2;3m before calling the build_statespace_graph method. \u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "ar3 = AutoRegressiveThree(mode=\"NUMBA\")" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "27119d27", "metadata": {}, "outputs": [], @@ -166,7 +135,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "d9163bda", "metadata": {}, "outputs": [], @@ -179,16 +148,7 @@ "execution_count": null, "id": "217b812d", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:77: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n", - " warnings.warn(NO_TIME_INDEX_WARNING)\n" - ] - } - ], + "outputs": [], "source": [ "# Not vectorized\n", "with pm.Model(coords=ar3.coords) as pymc_mod:\n", @@ -223,51 +183,10 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "ed208bcd", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Initializing NUTS using jitter+adapt_diag...\n", - "Multiprocess sampling (4 chains in 4 jobs)\n", - "NUTS: [ar_params, sigma_x]\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "54b5be47bbc44f7399db826f05108517", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Sampling 4 chains for 200 tune and 200 draw iterations (800 + 800 draws total) took 2 seconds.\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "with pymc_mod:\n",
     "    idata = pm.sample(tune=200, draws=200, compile_kwargs={\"mode\": \"NUMBA\"})"
@@ -275,62 +194,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": null,
    "id": "cf2af58a",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:77: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n",
-      "  warnings.warn(NO_TIME_INDEX_WARNING)\n",
-      "Sampling: [filtered_posterior, filtered_posterior_observed, predicted_posterior, predicted_posterior_observed, smoothed_posterior, smoothed_posterior_observed]\n"
-     ]
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "fe57e8f78975419887f133c358e24166",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Output()"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "post = ar3.sample_conditional_posterior(idata, mvn_method=\"cholesky\")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": null,
    "id": "5c3800a9",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Sampling: [ar_params, obs, sigma_x]\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "with pymc_mod:\n",
     "    prior = pm.sample_prior_predictive(compile_kwargs={\"mode\": \"NUMBA\"})"
@@ -338,1654 +215,30 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": null,
    "id": "5857a994",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:77: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n",
-      "  warnings.warn(NO_TIME_INDEX_WARNING)\n",
-      "Sampling: [filtered_prior, filtered_prior_observed, predicted_prior, predicted_prior_observed, smoothed_prior, smoothed_prior_observed]\n"
-     ]
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "c47107700d01475fbc7f96b068ea0aec",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Output()"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.DataTree 'posterior_predictive'>\n",
-       "Group: /posterior_predictive\n",
-       "    Dimensions:                   (chain: 1, draw: 500, time: 100, state: 3,\n",
-       "                                   observed_state: 1)\n",
-       "    Coordinates:\n",
-       "      * chain                     (chain) int64 8B 0\n",
-       "      * draw                      (draw) int64 4kB 0 1 2 3 4 ... 495 496 497 498 499\n",
-       "      * time                      (time) int64 800B 0 1 2 3 4 5 ... 95 96 97 98 99\n",
-       "      * state                     (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
-       "      * observed_state            (observed_state) <U3 12B 'ts1'\n",
-       "    Data variables:\n",
-       "        filtered_prior            (chain, draw, time, state) float64 1MB -0.1217 ...\n",
-       "        filtered_prior_observed   (chain, draw, time, observed_state) float64 400kB ...\n",
-       "        predicted_prior           (chain, draw, time, state) float64 1MB -1.732 ....\n",
-       "        predicted_prior_observed  (chain, draw, time, observed_state) float64 400kB ...\n",
-       "        smoothed_prior            (chain, draw, time, state) float64 1MB -0.1217 ...\n",
-       "        smoothed_prior_observed   (chain, draw, time, observed_state) float64 400kB ...\n",
-       "    Attributes:\n",
-       "        created_at:                 2026-05-20T16:02:32.953527+00:00\n",
-       "        creation_library:           ArviZ\n",
-       "        creation_library_version:   1.1.0\n",
-       "        creation_library_language:  Python\n",
-       "        inference_library:          pymc\n",
-       "        inference_library_version:  6.0.0\n",
-       "        sample_dims:                ['chain', 'draw']
" - ], - "text/plain": [ - "\n", - "Group: /posterior_predictive\n", - " Dimensions: (chain: 1, draw: 500, time: 100, state: 3,\n", - " observed_state: 1)\n", - " Coordinates:\n", - " * chain (chain) int64 8B 0\n", - " * draw (draw) int64 4kB 0 1 2 3 4 ... 495 496 497 498 499\n", - " * time (time) int64 800B 0 1 2 3 4 5 ... 95 96 97 98 99\n", - " * state (state) \n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.DataTree 'posterior_predictive'>\n",
-       "Group: /posterior_predictive\n",
-       "    Dimensions:         (chain: 1, draw: 500, time: 100, state: 3, observed_state: 1)\n",
-       "    Coordinates:\n",
-       "      * chain           (chain) int64 8B 0\n",
-       "      * draw            (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n",
-       "      * time            (time) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99\n",
-       "      * state           (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
-       "      * observed_state  (observed_state) <U3 12B 'ts1'\n",
-       "    Data variables:\n",
-       "        prior_latent    (chain, draw, time, state) float64 1MB -5.512 ... 6.971e+23\n",
-       "        prior_observed  (chain, draw, time, observed_state) float64 400kB -5.512 ...\n",
-       "    Attributes:\n",
-       "        created_at:                 2026-05-20T16:02:37.272795+00:00\n",
-       "        creation_library:           ArviZ\n",
-       "        creation_library_version:   1.1.0\n",
-       "        creation_library_language:  Python\n",
-       "        inference_library:          pymc\n",
-       "        inference_library_version:  6.0.0\n",
-       "        sample_dims:                ['chain', 'draw']
" - ], - "text/plain": [ - "\n", - "Group: /posterior_predictive\n", - " Dimensions: (chain: 1, draw: 500, time: 100, state: 3, observed_state: 1)\n", - " Coordinates:\n", - " * chain (chain) int64 8B 0\n", - " * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n", - " * time (time) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99\n", - " * state (state) \n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "unpost = ar3.sample_unconditional_posterior(idata, mvn_method=\"cholesky\")" ] @@ -2173,13 +426,163 @@ "pymc_mod.to_graphviz()" ] }, + { + "cell_type": "markdown", + "id": "5ac8b371", + "metadata": {}, + "source": [ + "# MISC" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57a8c532", + "metadata": {}, + "outputs": [], + "source": [ + "from pymc_extras.statespace import BayesianVARMAX" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "6c16eee6", + "id": "75ded7b0", + "metadata": {}, + "outputs": [], + "source": [ + "def varma_mod(data):\n", + " return BayesianVARMAX(\n", + " endog_names=data.columns,\n", + " order=(2, 0),\n", + " stationary_initialization=True,\n", + " verbose=False,\n", + " measurement_error=True,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb52c037", + "metadata": {}, + "outputs": [], + "source": [ + "def idata(pymc_mod, rng):\n", + " with pymc_mod:\n", + " idata = pm.sample_prior_predictive(draws=10, random_seed=rng)\n", + "\n", + " return idata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63265e6c", + "metadata": {}, + "outputs": [], + "source": [ + "def pymc_mod(varma_mod, data):\n", + " with pm.Model(coords=varma_mod.coords) as pymc_mod:\n", + " state_chol, *_ = pm.LKJCholeskyCov(\n", + " \"state_chol\", n=varma_mod.k_posdef, eta=1, sd_dist=pm.Exponential.dist(1)\n", + " )\n", + " ar_params = pm.Normal(\n", + " \"ar_params\", mu=0, sigma=0.1, dims=[\"observed_state\", \"lag_ar\", \"observed_state_aux\"]\n", + " )\n", + " state_cov = pm.Deterministic(\n", + " \"state_cov\", state_chol @ state_chol.T, dims=[\"shock\", \"shock_aux\"]\n", + " )\n", + " sigma_obs = pm.Exponential(\"sigma_obs\", 1, dims=[\"observed_state\"])\n", + "\n", + " varma_mod.build_statespace_graph(data=data, save_kalman_filter_outputs_in_idata=True)\n", + "\n", + " return pymc_mod" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2eb92902", + "metadata": {}, + "outputs": [], + "source": [ + "rng = np.random.default_rng()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "461f37bd", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import pytensor\n", + "\n", + "floatX = pytensor.config.floatX" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c365d89b", + "metadata": {}, + "outputs": [], + "source": [ + "def data():\n", + " df = pd.read_csv(\n", + " \"../tests/statespace/_data/statsmodels_macrodata_processed.csv\",\n", + " index_col=0,\n", + " parse_dates=True,\n", + " ).astype(floatX)\n", + " df.index.freq = df.index.inferred_freq\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "443ad7aa", + "metadata": {}, + "outputs": [], + "source": [ + "data = data()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3426a7dd", + "metadata": {}, + "outputs": [], + "source": [ + "varma_mod_ = varma_mod(data)\n", + "pymc_mod_ = pymc_mod(varma_mod_, data)\n", + "idata_ = idata(pymc_mod_, rng)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f32f55e", "metadata": {}, "outputs": [], "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e976e8b9", + "metadata": {}, + "outputs": [], + "source": [ + "def test_forecast(varma_mod, idata, rng):\n", + " forecast = varma_mod.forecast(idata.prior, periods=10, random_seed=rng)\n", + "\n", + " assert np.isfinite(forecast.forecast_latent.values).all()\n", + " assert np.isfinite(forecast.forecast_observed.values).all()" + ] } ], "metadata": { diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index e8dbb22c4..96a01e93c 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -259,6 +259,7 @@ def __init__( k_endog: int, k_states: int, k_posdef: int, + batch_size: int | None = None, filter_type: str = "standard", verbose: bool = True, measurement_error: bool = False, @@ -279,6 +280,7 @@ def __init__( self.k_posdef = k_posdef self.measurement_error = measurement_error self.mode = mode + self.batch_size = batch_size self._populate_properties() @@ -874,14 +876,12 @@ def _insert_random_variables(self): matrices = list(self._unpack_statespace_with_placeholders()) - if pymc_model["x0"].type.ndim > 1: - self.batch_dims = pymc_model["x0"].type.shape[0] + if self.batch_size: replacement_dict = { var: pymc_model[name] for name, var in self._name_to_variable.items() } self.subbed_ssm = vectorize_graph(matrices, replace=replacement_dict) else: - self.batch_dims = None replacement_dict = { var: pymc_model[name] for name, var in self._name_to_variable.items() } @@ -1221,7 +1221,7 @@ def build_statespace_graph( data_dims = None - if data.ndim > 2: + if self.batch_size: data_dims = (BATCH_DIM, TIME_DIM, OBS_STATE_DIM) data, nan_mask = register_data_with_pymc( @@ -1355,8 +1355,8 @@ def _build_dummy_graph(self) -> None: def infer_variable_shape(name): shape = self._name_to_variable[name].type.shape - if self.batch_dims: - shape = (self.batch_dims, *shape) + if self.batch_size: + shape = (self.batch_size, *shape) if not any(dim is None for dim in shape): return shape @@ -1440,7 +1440,7 @@ def _kalman_filter_outputs_from_dummy_graph( obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None) - if data.ndim > 2: + if self.batch_size: data_dims = (BATCH_DIM, TIME_DIM, OBS_STATE_DIM) data, nan_mask = register_data_with_pymc( @@ -1589,7 +1589,7 @@ def _sample_conditional( for name, (mu, cov) in zip(FILTER_OUTPUT_TYPES, grouped_outputs): dummy_ll = pt.zeros_like(mu) - if self.batch_dims: + if self.batch_size: state_dims = ( (BATCH_DIM, TIME_DIM, ALL_STATE_DIM) if all( @@ -1745,7 +1745,7 @@ def _sample_unconditional( else: steps = len(temp_coords[TIME_DIM]) - 1 - if self.batch_dims: + if self.batch_size: if all( [ dim in self._fit_coords diff --git a/pymc_extras/statespace/models/ETS.py b/pymc_extras/statespace/models/ETS.py index dbe6dfcfa..9e1d26294 100644 --- a/pymc_extras/statespace/models/ETS.py +++ b/pymc_extras/statespace/models/ETS.py @@ -289,7 +289,7 @@ def __init__( k_endog, k_states, k_posdef, - filter_type, + filter_type=filter_type, verbose=verbose, measurement_error=measurement_error, mode=mode, diff --git a/pymc_extras/statespace/models/SARIMAX.py b/pymc_extras/statespace/models/SARIMAX.py index 11c0c6850..aac51505b 100644 --- a/pymc_extras/statespace/models/SARIMAX.py +++ b/pymc_extras/statespace/models/SARIMAX.py @@ -265,7 +265,7 @@ def __init__( k_endog, k_states, k_posdef, - filter_type, + filter_type=filter_type, verbose=verbose, measurement_error=measurement_error, mode=mode, diff --git a/pymc_extras/statespace/models/VARMAX.py b/pymc_extras/statespace/models/VARMAX.py index 7677b092c..cab85fa72 100644 --- a/pymc_extras/statespace/models/VARMAX.py +++ b/pymc_extras/statespace/models/VARMAX.py @@ -198,7 +198,7 @@ def __init__( k_endog, k_states, k_posdef, - filter_type, + filter_type=filter_type, verbose=verbose, measurement_error=measurement_error, mode=mode, diff --git a/pymc_extras/statespace/utils/data_tools.py b/pymc_extras/statespace/utils/data_tools.py index c1255e2ef..8abc1c335 100644 --- a/pymc_extras/statespace/utils/data_tools.py +++ b/pymc_extras/statespace/utils/data_tools.py @@ -127,10 +127,11 @@ def preprocess_pandas_data(data, n_obs, obs_coords=None, check_column_names=Fals def add_data_to_active_model(values, index, data_dims=None, batched=False): pymc_mod = modelcontext(None) - if data_dims is None and not batched: - data_dims = [TIME_DIM, OBS_STATE_DIM] - else: - data_dims = [BATCH_DIM, TIME_DIM, OBS_STATE_DIM] + if data_dims is None: + if not batched: + data_dims = [TIME_DIM, OBS_STATE_DIM] + else: + data_dims = [BATCH_DIM, TIME_DIM, OBS_STATE_DIM] time_dim = data_dims[0] if not batched else data_dims[1] From d58eb1f97d19661bc93f4b18e26276c93c001363 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Fri, 22 May 2026 11:04:02 -0600 Subject: [PATCH 03/10] updated forecast method to support batched data --- .../temporary_scratchpad_ssm_batch_dims.ipynb | 38 +++++++++++- pymc_extras/statespace/core/statespace.py | 59 ++++++++++++++----- 2 files changed, 78 insertions(+), 19 deletions(-) diff --git a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb index f3a0d405a..966eafdc8 100644 --- a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb +++ b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb @@ -11,7 +11,15 @@ "* `sample_unconditional_prior`\n", "* `sample_conditional_prior`\n", "* `sample_unconditional_posterior`\n", - "* `sample_conditional_posterior`" + "* `sample_conditional_posterior`\n", + "* `forecast`\n", + "\n", + "\n", + "Need to also add\n", + "\n", + "* `sample_statespace_matrices`\n", + "* `sample_filter_outputs`\n", + "* `impulse_response_function`" ] }, { @@ -51,10 +59,14 @@ " k_states = 3 # size of the state vector x\n", " k_posdef = 1 # number of shocks (size of the state covariance matrix Q)\n", " k_endog = 1 # number of observed states\n", - " batch_size = 2\n", + " # batch_size = 2\n", "\n", " super().__init__(\n", - " k_endog=k_endog, k_states=k_states, k_posdef=k_posdef, batch_size=2, mode=mode\n", + " k_endog=k_endog,\n", + " k_states=k_states,\n", + " k_posdef=k_posdef,\n", + " # batch_size=batch_size,\n", + " mode=mode,\n", " )\n", "\n", " def make_symbolic_graph(self):\n", @@ -192,6 +204,26 @@ " idata = pm.sample(tune=200, draws=200, compile_kwargs={\"mode\": \"NUMBA\"})" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "64940705", + "metadata": {}, + "outputs": [], + "source": [ + "ar3.sample_filter_outputs(idata, filter_output_names=[\"filtered_states\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92b4b9f3", + "metadata": {}, + "outputs": [], + "source": [ + "ar3.forecast(idata, periods=10)" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 96a01e93c..5d6c1ddc3 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -1440,7 +1440,7 @@ def _kalman_filter_outputs_from_dummy_graph( obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None) - if self.batch_size: + if self.batch_size and data_dims is None: data_dims = (BATCH_DIM, TIME_DIM, OBS_STATE_DIM) data, nan_mask = register_data_with_pymc( @@ -2553,23 +2553,38 @@ def _build_forecast_model( filter_time_dim = TIME_DIM temp_coords = self._fit_coords.copy() - dims = None - if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]): - dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] - t0_idx = np.flatnonzero(time_index == t0)[0] temp_coords["data_time"] = time_index temp_coords[TIME_DIM] = forecast_index - mu_dims, cov_dims = None, None - if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]): - mu_dims = ["data_time", ALL_STATE_DIM] - cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM] + dims, mu_dims, cov_dims = None, None, None + + if self.batch_size: + data_dims = [BATCH_DIM, "data_time", OBS_STATE_DIM] + if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]): + dims = [BATCH_DIM, TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] + if all( + [ + dim in self._fit_coords + for dim in [BATCH_DIM, TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM] + ] + ): + mu_dims = [BATCH_DIM, ALL_STATE_DIM] + cov_dims = [BATCH_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM] + else: + data_dims = ["data_time", OBS_STATE_DIM] + if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]): + dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM] + if all( + [dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]] + ): + mu_dims = [ALL_STATE_DIM] + cov_dims = [ALL_STATE_DIM, ALL_STATE_AUX_DIM] with pm.Model(coords=temp_coords) as forecast_model: _, grouped_outputs = self._kalman_filter_outputs_from_dummy_graph( - data_dims=["data_time", OBS_STATE_DIM], + data_dims=data_dims, ) group_idx = FILTER_OUTPUT_TYPES.index(filter_output) @@ -2588,12 +2603,24 @@ def _build_forecast_model( mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True) - x0 = pm.Deterministic( - "x0_slice", mu_frozen[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None - ) - P0 = pm.Deterministic( - "P0_slice", cov_frozen[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None - ) + if self.batch_size: + x0 = pm.Deterministic( + "x0_slice", + mu_frozen[:, t0_idx, :], + dims=mu_dims if mu_dims is not None else None, + ) + P0 = pm.Deterministic( + "P0_slice", + cov_frozen[:, t0_idx, :, :], + dims=cov_dims if cov_dims is not None else None, + ) + else: + x0 = pm.Deterministic( + "x0_slice", mu_frozen[t0_idx], dims=mu_dims if mu_dims is not None else None + ) + P0 = pm.Deterministic( + "P0_slice", cov_frozen[t0_idx], dims=cov_dims if cov_dims is not None else None + ) # Get fresh matrices with n_timesteps placeholder still intact. # Build for the full timeline (training + forecast) so that time-varying matrices From 2f4f6b2d4afca00f0d2dffaa5ed28233db9565b4 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Fri, 22 May 2026 11:43:23 -0600 Subject: [PATCH 04/10] updated sample_filter_outputs to support batched data --- .../temporary_scratchpad_ssm_batch_dims.ipynb | 800 +++++++++++++++++- pymc_extras/statespace/core/statespace.py | 98 ++- 2 files changed, 861 insertions(+), 37 deletions(-) diff --git a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb index 966eafdc8..042c20eb2 100644 --- a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb +++ b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb @@ -24,7 +24,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "d95e629f", "metadata": {}, "outputs": [], @@ -49,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "34b32aa3", "metadata": {}, "outputs": [], @@ -127,17 +127,51 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "id": "bfae06f5", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
                          Model Requirements                           \n",
+       "                                                                       \n",
+       "  Variable    Shape    Constraints                         Dimensions  \n",
+       " ───────────────────────────────────────────────────────────────────── \n",
+       "  x0          (3,)                                         ('state',)  \n",
+       "  P0          (3, 3)                           ('state', 'state_aux')  \n",
+       "  ar_params   (3,)     Stationary, please :)             ('ar_lags',)  \n",
+       "  sigma_x     (1,)                                         ('shock',)  \n",
+       "                                                                       \n",
+       " These parameters should be assigned priors inside a PyMC model block  \n",
+       "           before calling the build_statespace_graph method.           \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[3m Model Requirements \u001b[0m\n", + " \n", + " \u001b[1m \u001b[0m\u001b[1mVariable \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mShape \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mConstraints \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1m Dimensions\u001b[0m\u001b[1m \u001b[0m \n", + " ───────────────────────────────────────────────────────────────────── \n", + " x0 \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m,\u001b[1m)\u001b[0m \n", + " P0 \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m, \u001b[32m'state_aux'\u001b[0m\u001b[1m)\u001b[0m \n", + " ar_params \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m,\u001b[1m)\u001b[0m Stationary, please :\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'ar_lags'\u001b[0m,\u001b[1m)\u001b[0m \n", + " sigma_x \u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'shock'\u001b[0m,\u001b[1m)\u001b[0m \n", + " \n", + "\u001b[2;3m These parameters should be assigned priors inside a PyMC model block \u001b[0m\n", + "\u001b[2;3m before calling the build_statespace_graph method. \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "ar3 = AutoRegressiveThree(mode=\"NUMBA\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "id": "27119d27", "metadata": {}, "outputs": [], @@ -147,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "d9163bda", "metadata": {}, "outputs": [], @@ -157,10 +191,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "217b812d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:78: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n", + " warnings.warn(NO_TIME_INDEX_WARNING)\n" + ] + } + ], "source": [ "# Not vectorized\n", "with pm.Model(coords=ar3.coords) as pymc_mod:\n", @@ -175,10 +218,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "84ac6e46", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:78: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n", + " warnings.warn(NO_TIME_INDEX_WARNING)\n" + ] + } + ], "source": [ "# Vectorized\n", "with pm.Model(coords=ar3.coords | {\"batch\": [\"batch_1\", \"batch_2\"]}) as pymc_mod:\n", @@ -195,10 +247,51 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "id": "ed208bcd", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Initializing NUTS using jitter+adapt_diag...\n", + "Multiprocess sampling (4 chains in 4 jobs)\n", + "NUTS: [ar_params, sigma_x]\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "51570fe8087f4c28aaf4ef7662d4e488", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Sampling 4 chains for 200 tune and 200 draw iterations (800 + 800 draws total) took 2 seconds.\n"
+     ]
+    }
+   ],
    "source": [
     "with pymc_mod:\n",
     "    idata = pm.sample(tune=200, draws=200, compile_kwargs={\"mode\": \"NUMBA\"})"
@@ -206,12 +299,689 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 15,
    "id": "64940705",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:78: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n",
+      "  warnings.warn(NO_TIME_INDEX_WARNING)\n",
+      "Sampling: []\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "32292b1a794e4fa7be31153fb56b19ea",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataTree>\n",
+       "Group: /\n",
+       "├── Group: /posterior_predictive\n",
+       "│       Dimensions:               (chain: 4, draw: 200, time: 100, state: 3,\n",
+       "│                                  state_aux: 3)\n",
+       "│       Coordinates:\n",
+       "│         * chain                 (chain) int64 32B 0 1 2 3\n",
+       "│         * draw                  (draw) int64 2kB 0 1 2 3 4 5 ... 195 196 197 198 199\n",
+       "│         * time                  (time) int64 800B 0 1 2 3 4 5 6 ... 94 95 96 97 98 99\n",
+       "│         * state                 (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
+       "│         * state_aux             (state_aux) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
+       "│       Data variables:\n",
+       "│           filtered_covariances  (chain, draw, time, state, state_aux) float64 6MB 1...\n",
+       "│       Attributes:\n",
+       "│           created_at:                 2026-05-22T17:38:10.782806+00:00\n",
+       "│           creation_library:           ArviZ\n",
+       "│           creation_library_version:   1.1.0\n",
+       "│           creation_library_language:  Python\n",
+       "│           inference_library:          pymc\n",
+       "│           inference_library_version:  6.0.0\n",
+       "│           sample_dims:                ['chain', 'draw']\n",
+       "├── Group: /observed_data\n",
+       "│       Attributes:\n",
+       "│           created_at:                 2026-05-22T17:38:10.784674+00:00\n",
+       "│           creation_library:           ArviZ\n",
+       "│           creation_library_version:   1.1.0\n",
+       "│           creation_library_language:  Python\n",
+       "│           inference_library:          pymc\n",
+       "│           inference_library_version:  6.0.0\n",
+       "│           sample_dims:                []\n",
+       "└── Group: /constant_data\n",
+       "        Dimensions:         (time: 100, observed_state: 1)\n",
+       "        Coordinates:\n",
+       "          * time            (time) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99\n",
+       "          * observed_state  (observed_state) <U3 12B 'ts1'\n",
+       "        Data variables:\n",
+       "            data            (time, observed_state) float64 800B 0.0816 ... 0.4303\n",
+       "        Attributes:\n",
+       "            created_at:                 2026-05-22T17:38:10.785346+00:00\n",
+       "            creation_library:           ArviZ\n",
+       "            creation_library_version:   1.1.0\n",
+       "            creation_library_language:  Python\n",
+       "            inference_library:          pymc\n",
+       "            inference_library_version:  6.0.0\n",
+       "            sample_dims:                []
" + ], + "text/plain": [ + "\n", + "Group: /\n", + "├── Group: /posterior_predictive\n", + "│ Dimensions: (chain: 4, draw: 200, time: 100, state: 3,\n", + "│ state_aux: 3)\n", + "│ Coordinates:\n", + "│ * chain (chain) int64 32B 0 1 2 3\n", + "│ * draw (draw) int64 2kB 0 1 2 3 4 5 ... 195 196 197 198 199\n", + "│ * time (time) int64 800B 0 1 2 3 4 5 6 ... 94 95 96 97 98 99\n", + "│ * state (state) Date: Fri, 22 May 2026 12:20:18 -0600 Subject: [PATCH 05/10] updated sample_statespace_matrices to support batched data --- .../temporary_scratchpad_ssm_batch_dims.ipynb | 128 +++++++----------- pymc_extras/statespace/core/statespace.py | 10 +- 2 files changed, 57 insertions(+), 81 deletions(-) diff --git a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb index 042c20eb2..88a821a03 100644 --- a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb +++ b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb @@ -13,12 +13,12 @@ "* `sample_unconditional_posterior`\n", "* `sample_conditional_posterior`\n", "* `forecast`\n", + "* `sample_filter_outputs`\n", "\n", "\n", "Need to also add\n", "\n", "* `sample_statespace_matrices`\n", - "* `sample_filter_outputs`\n", "* `impulse_response_function`" ] }, @@ -49,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 2, "id": "34b32aa3", "metadata": {}, "outputs": [], @@ -59,13 +59,13 @@ " k_states = 3 # size of the state vector x\n", " k_posdef = 1 # number of shocks (size of the state covariance matrix Q)\n", " k_endog = 1 # number of observed states\n", - " # batch_size = 2\n", + " batch_size = 2\n", "\n", " super().__init__(\n", " k_endog=k_endog,\n", " k_states=k_states,\n", " k_posdef=k_posdef,\n", - " # batch_size=batch_size,\n", + " batch_size=batch_size,\n", " mode=mode,\n", " )\n", "\n", @@ -127,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 3, "id": "bfae06f5", "metadata": {}, "outputs": [ @@ -171,7 +171,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 4, "id": "27119d27", "metadata": {}, "outputs": [], @@ -181,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 5, "id": "d9163bda", "metadata": {}, "outputs": [], @@ -191,19 +191,10 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "217b812d", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:78: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n", - " warnings.warn(NO_TIME_INDEX_WARNING)\n" - ] - } - ], + "outputs": [], "source": [ "# Not vectorized\n", "with pm.Model(coords=ar3.coords) as pymc_mod:\n", @@ -247,7 +238,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 7, "id": "ed208bcd", "metadata": {}, "outputs": [ @@ -263,7 +254,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "51570fe8087f4c28aaf4ef7662d4e488", + "model_id": "29521a3c6fce4dbf83a90cb852930b6f", "version_major": 2, "version_minor": 0 }, @@ -288,7 +279,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "Sampling 4 chains for 200 tune and 200 draw iterations (800 + 800 draws total) took 2 seconds.\n" + "Sampling 4 chains for 200 tune and 200 draw iterations (800 + 800 draws total) took 7 seconds.\n", + "The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details\n" ] } ], @@ -299,23 +291,21 @@ }, { "cell_type": "code", - "execution_count": 15, - "id": "64940705", + "execution_count": 8, + "id": "24035cf0", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:78: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n", - " warnings.warn(NO_TIME_INDEX_WARNING)\n", "Sampling: []\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "32292b1a794e4fa7be31153fb56b19ea", + "model_id": "5662f8c19923423399bd20e1a3fced0a", "version_major": 2, "version_minor": 0 }, @@ -884,89 +874,57 @@ "
<xarray.DataTree>\n",
        "Group: /\n",
        "├── Group: /posterior_predictive\n",
-       "│       Dimensions:               (chain: 4, draw: 200, time: 100, state: 3,\n",
-       "│                                  state_aux: 3)\n",
+       "│       Dimensions:    (chain: 4, draw: 200, batch: 2, state: 3, state_aux: 3)\n",
        "│       Coordinates:\n",
-       "│         * chain                 (chain) int64 32B 0 1 2 3\n",
-       "│         * draw                  (draw) int64 2kB 0 1 2 3 4 5 ... 195 196 197 198 199\n",
-       "│         * time                  (time) int64 800B 0 1 2 3 4 5 6 ... 94 95 96 97 98 99\n",
-       "│         * state                 (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
-       "│         * state_aux             (state_aux) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
+       "│         * chain      (chain) int64 32B 0 1 2 3\n",
+       "│         * draw       (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n",
+       "│         * batch      (batch) <U7 56B 'batch_1' 'batch_2'\n",
+       "│         * state      (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
+       "│         * state_aux  (state_aux) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
        "│       Data variables:\n",
-       "│           filtered_covariances  (chain, draw, time, state, state_aux) float64 6MB 1...\n",
+       "│           T          (chain, draw, batch, state, state_aux) float64 115kB -0.00538 ...\n",
        "│       Attributes:\n",
-       "│           created_at:                 2026-05-22T17:38:10.782806+00:00\n",
+       "│           created_at:                 2026-05-22T18:15:15.436712+00:00\n",
        "│           creation_library:           ArviZ\n",
        "│           creation_library_version:   1.1.0\n",
        "│           creation_library_language:  Python\n",
        "│           inference_library:          pymc\n",
        "│           inference_library_version:  6.0.0\n",
        "│           sample_dims:                ['chain', 'draw']\n",
-       "├── Group: /observed_data\n",
-       "│       Attributes:\n",
-       "│           created_at:                 2026-05-22T17:38:10.784674+00:00\n",
-       "│           creation_library:           ArviZ\n",
-       "│           creation_library_version:   1.1.0\n",
-       "│           creation_library_language:  Python\n",
-       "│           inference_library:          pymc\n",
-       "│           inference_library_version:  6.0.0\n",
-       "│           sample_dims:                []\n",
-       "└── Group: /constant_data\n",
-       "        Dimensions:         (time: 100, observed_state: 1)\n",
-       "        Coordinates:\n",
-       "          * time            (time) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99\n",
-       "          * observed_state  (observed_state) <U3 12B 'ts1'\n",
-       "        Data variables:\n",
-       "            data            (time, observed_state) float64 800B 0.0816 ... 0.4303\n",
+       "└── Group: /observed_data\n",
        "        Attributes:\n",
-       "            created_at:                 2026-05-22T17:38:10.785346+00:00\n",
+       "            created_at:                 2026-05-22T18:15:15.438048+00:00\n",
        "            creation_library:           ArviZ\n",
        "            creation_library_version:   1.1.0\n",
        "            creation_library_language:  Python\n",
        "            inference_library:          pymc\n",
        "            inference_library_version:  6.0.0\n",
-       "            sample_dims:                []
" + " sample_dims: []" ], "text/plain": [ "\n", "Group: /\n", "├── Group: /posterior_predictive\n", - "│ Dimensions: (chain: 4, draw: 200, time: 100, state: 3,\n", - "│ state_aux: 3)\n", + "│ Dimensions: (chain: 4, draw: 200, batch: 2, state: 3, state_aux: 3)\n", "│ Coordinates:\n", - "│ * chain (chain) int64 32B 0 1 2 3\n", - "│ * draw (draw) int64 2kB 0 1 2 3 4 5 ... 195 196 197 198 199\n", - "│ * time (time) int64 800B 0 1 2 3 4 5 6 ... 94 95 96 97 98 99\n", - "│ * state (state) Date: Fri, 22 May 2026 13:17:47 -0600 Subject: [PATCH 06/10] updated impulse_response_function to support batched data --- .../temporary_scratchpad_ssm_batch_dims.ipynb | 202 +++++++++++------- pymc_extras/statespace/core/statespace.py | 67 ++++-- 2 files changed, 182 insertions(+), 87 deletions(-) diff --git a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb index 88a821a03..fcb454968 100644 --- a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb +++ b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb @@ -49,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 9, "id": "34b32aa3", "metadata": {}, "outputs": [], @@ -127,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 10, "id": "bfae06f5", "metadata": {}, "outputs": [ @@ -171,7 +171,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 11, "id": "27119d27", "metadata": {}, "outputs": [], @@ -181,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 12, "id": "d9163bda", "metadata": {}, "outputs": [], @@ -191,10 +191,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "217b812d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:78: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n", + " warnings.warn(NO_TIME_INDEX_WARNING)\n" + ] + } + ], "source": [ "# Not vectorized\n", "with pm.Model(coords=ar3.coords) as pymc_mod:\n", @@ -209,7 +218,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 13, "id": "84ac6e46", "metadata": {}, "outputs": [ @@ -238,7 +247,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 14, "id": "ed208bcd", "metadata": {}, "outputs": [ @@ -254,7 +263,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "29521a3c6fce4dbf83a90cb852930b6f", + "model_id": "e937a354e8234dc2b633e7b2af3d6b1e", "version_major": 2, "version_minor": 0 }, @@ -279,8 +288,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Sampling 4 chains for 200 tune and 200 draw iterations (800 + 800 draws total) took 7 seconds.\n", - "The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details\n" + "Sampling 4 chains for 200 tune and 200 draw iterations (800 + 800 draws total) took 7 seconds.\n" ] } ], @@ -291,21 +299,21 @@ }, { "cell_type": "code", - "execution_count": 8, - "id": "24035cf0", + "execution_count": 15, + "id": "c172bc3a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Sampling: []\n" + "Sampling: [initial_shock]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "5662f8c19923423399bd20e1a3fced0a", + "model_id": "9b4037fed60b483d8b4286c19986e634", "version_major": 2, "version_minor": 0 }, @@ -871,73 +879,119 @@ " filter: drop-shadow(1px 1px 5px var(--xr-font-color2));\n", " stroke-width: 0.8px;\n", "}\n", - "
<xarray.DataTree>\n",
-       "Group: /\n",
-       "├── Group: /posterior_predictive\n",
-       "│       Dimensions:    (chain: 4, draw: 200, batch: 2, state: 3, state_aux: 3)\n",
-       "│       Coordinates:\n",
-       "│         * chain      (chain) int64 32B 0 1 2 3\n",
-       "│         * draw       (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n",
-       "│         * batch      (batch) <U7 56B 'batch_1' 'batch_2'\n",
-       "│         * state      (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
-       "│         * state_aux  (state_aux) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
-       "│       Data variables:\n",
-       "│           T          (chain, draw, batch, state, state_aux) float64 115kB -0.00538 ...\n",
-       "│       Attributes:\n",
-       "│           created_at:                 2026-05-22T18:15:15.436712+00:00\n",
-       "│           creation_library:           ArviZ\n",
-       "│           creation_library_version:   1.1.0\n",
-       "│           creation_library_language:  Python\n",
-       "│           inference_library:          pymc\n",
-       "│           inference_library_version:  6.0.0\n",
-       "│           sample_dims:                ['chain', 'draw']\n",
-       "└── Group: /observed_data\n",
-       "        Attributes:\n",
-       "            created_at:                 2026-05-22T18:15:15.438048+00:00\n",
-       "            creation_library:           ArviZ\n",
-       "            creation_library_version:   1.1.0\n",
-       "            creation_library_language:  Python\n",
-       "            inference_library:          pymc\n",
-       "            inference_library_version:  6.0.0\n",
-       "            sample_dims:                []
" + "
<xarray.DataTree 'posterior_predictive'>\n",
+       "Group: /posterior_predictive\n",
+       "    Dimensions:  (chain: 4, draw: 200, batch: 2, time: 40, state: 3)\n",
+       "    Coordinates:\n",
+       "      * chain    (chain) int64 32B 0 1 2 3\n",
+       "      * draw     (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n",
+       "      * batch    (batch) <U7 56B 'batch_1' 'batch_2'\n",
+       "      * time     (time) int64 320B 0 1 2 3 4 5 6 7 8 ... 31 32 33 34 35 36 37 38 39\n",
+       "      * state    (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
+       "    Data variables:\n",
+       "        irf      (chain, draw, batch, time, state) float64 2MB 0.2783 ... -2.115e-10\n",
+       "    Attributes:\n",
+       "        created_at:                 2026-05-22T19:12:13.521657+00:00\n",
+       "        creation_library:           ArviZ\n",
+       "        creation_library_version:   1.1.0\n",
+       "        creation_library_language:  Python\n",
+       "        inference_library:          pymc\n",
+       "        inference_library_version:  6.0.0\n",
+       "        sample_dims:                ['chain', 'draw']
" ], "text/plain": [ - "\n", - "Group: /\n", - "├── Group: /posterior_predictive\n", - "│ Dimensions: (chain: 4, draw: 200, batch: 2, state: 3, state_aux: 3)\n", - "│ Coordinates:\n", - "│ * chain (chain) int64 32B 0 1 2 3\n", - "│ * draw (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n", - "│ * batch (batch) \n", + "Group: /posterior_predictive\n", + " Dimensions: (chain: 4, draw: 200, batch: 2, time: 40, state: 3)\n", + " Coordinates:\n", + " * chain (chain) int64 32B 0 1 2 3\n", + " * draw (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n", + " * batch (batch) Date: Tue, 26 May 2026 06:13:36 -0600 Subject: [PATCH 07/10] refactored vectorization methods, pulled out repeated helpers into batch_tools utilities file, use slicing logic instead of branching for forecast indexing, make batch_size a tuple for destructuring instead of branching --- .../temporary_scratchpad_ssm_batch_dims.ipynb | 956 +----------------- pymc_extras/statespace/core/statespace.py | 177 ++-- .../statespace/filters/distributions.py | 5 +- pymc_extras/statespace/utils/batch_tools.py | 5 + 4 files changed, 106 insertions(+), 1037 deletions(-) create mode 100644 pymc_extras/statespace/utils/batch_tools.py diff --git a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb index fcb454968..94eff8375 100644 --- a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb +++ b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb @@ -14,17 +14,13 @@ "* `sample_conditional_posterior`\n", "* `forecast`\n", "* `sample_filter_outputs`\n", - "\n", - "\n", - "Need to also add\n", - "\n", "* `sample_statespace_matrices`\n", - "* `impulse_response_function`" + "* `impulse_response_function`\n" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "d95e629f", "metadata": {}, "outputs": [], @@ -49,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "34b32aa3", "metadata": {}, "outputs": [], @@ -127,51 +123,17 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "id": "bfae06f5", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
                          Model Requirements                           \n",
-       "                                                                       \n",
-       "  Variable    Shape    Constraints                         Dimensions  \n",
-       " ───────────────────────────────────────────────────────────────────── \n",
-       "  x0          (3,)                                         ('state',)  \n",
-       "  P0          (3, 3)                           ('state', 'state_aux')  \n",
-       "  ar_params   (3,)     Stationary, please :)             ('ar_lags',)  \n",
-       "  sigma_x     (1,)                                         ('shock',)  \n",
-       "                                                                       \n",
-       " These parameters should be assigned priors inside a PyMC model block  \n",
-       "           before calling the build_statespace_graph method.           \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[3m Model Requirements \u001b[0m\n", - " \n", - " \u001b[1m \u001b[0m\u001b[1mVariable \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mShape \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mConstraints \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1m Dimensions\u001b[0m\u001b[1m \u001b[0m \n", - " ───────────────────────────────────────────────────────────────────── \n", - " x0 \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m,\u001b[1m)\u001b[0m \n", - " P0 \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m, \u001b[32m'state_aux'\u001b[0m\u001b[1m)\u001b[0m \n", - " ar_params \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m,\u001b[1m)\u001b[0m Stationary, please :\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'ar_lags'\u001b[0m,\u001b[1m)\u001b[0m \n", - " sigma_x \u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'shock'\u001b[0m,\u001b[1m)\u001b[0m \n", - " \n", - "\u001b[2;3m These parameters should be assigned priors inside a PyMC model block \u001b[0m\n", - "\u001b[2;3m before calling the build_statespace_graph method. \u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "ar3 = AutoRegressiveThree(mode=\"NUMBA\")" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "27119d27", "metadata": {}, "outputs": [], @@ -181,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "d9163bda", "metadata": {}, "outputs": [], @@ -191,19 +153,10 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "217b812d", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:78: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n", - " warnings.warn(NO_TIME_INDEX_WARNING)\n" - ] - } - ], + "outputs": [], "source": [ "# Not vectorized\n", "with pm.Model(coords=ar3.coords) as pymc_mod:\n", @@ -218,19 +171,10 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "84ac6e46", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:78: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n", - " warnings.warn(NO_TIME_INDEX_WARNING)\n" - ] - } - ], + "outputs": [], "source": [ "# Vectorized\n", "with pm.Model(coords=ar3.coords | {\"batch\": [\"batch_1\", \"batch_2\"]}) as pymc_mod:\n", @@ -247,51 +191,10 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "id": "ed208bcd", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Initializing NUTS using jitter+adapt_diag...\n", - "Multiprocess sampling (4 chains in 4 jobs)\n", - "NUTS: [ar_params, sigma_x]\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e937a354e8234dc2b633e7b2af3d6b1e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Sampling 4 chains for 200 tune and 200 draw iterations (800 + 800 draws total) took 7 seconds.\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "with pymc_mod:\n",
     "    idata = pm.sample(tune=200, draws=200, compile_kwargs={\"mode\": \"NUMBA\"})"
@@ -299,689 +202,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": null,
    "id": "c172bc3a",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Sampling: [initial_shock]\n"
-     ]
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "9b4037fed60b483d8b4286c19986e634",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Output()"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.DataTree 'posterior_predictive'>\n",
-       "Group: /posterior_predictive\n",
-       "    Dimensions:  (chain: 4, draw: 200, batch: 2, time: 40, state: 3)\n",
-       "    Coordinates:\n",
-       "      * chain    (chain) int64 32B 0 1 2 3\n",
-       "      * draw     (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n",
-       "      * batch    (batch) <U7 56B 'batch_1' 'batch_2'\n",
-       "      * time     (time) int64 320B 0 1 2 3 4 5 6 7 8 ... 31 32 33 34 35 36 37 38 39\n",
-       "      * state    (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
-       "    Data variables:\n",
-       "        irf      (chain, draw, batch, time, state) float64 2MB 0.2783 ... -2.115e-10\n",
-       "    Attributes:\n",
-       "        created_at:                 2026-05-22T19:12:13.521657+00:00\n",
-       "        creation_library:           ArviZ\n",
-       "        creation_library_version:   1.1.0\n",
-       "        creation_library_language:  Python\n",
-       "        inference_library:          pymc\n",
-       "        inference_library_version:  6.0.0\n",
-       "        sample_dims:                ['chain', 'draw']
" - ], - "text/plain": [ - "\n", - "Group: /posterior_predictive\n", - " Dimensions: (chain: 4, draw: 200, batch: 2, time: 40, state: 3)\n", - " Coordinates:\n", - " * chain (chain) int64 32B 0 1 2 3\n", - " * draw (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n", - " * batch (batch) {output_sig}" + + def _vectorize(self, fn, inputs, outputs, time_varying_names): + signature = self._build_signature( + inputs, + outputs, + time_varying_names, + ) + + return pt.vectorize(fn, signature=signature) + def make_vectorized_filter( self, missing_fill_value, cov_jitter, time_varying_names, ): - def maybe_tv(name, base_sig): - """ - Add leading time dimension if matrix is time varying. - """ - if name in time_varying_names: - return f"(t,{base_sig})" - return f"({base_sig})" - def vectorize_filter(data, x0, P0, c, d, T, Z, R, H, Q): return self.kalman_filter.build_graph( data, @@ -1053,49 +1077,37 @@ def vectorize_filter(data, x0, P0, c, d, T, Z, R, H, Q): time_varying_names, ) - input_signature = ",".join( - [ - "(t,o)", # data - "(k)", # Initial state - "(k,k)", # Initial cov - maybe_tv("c", "k"), # state eq. intercept - maybe_tv("d", "o"), # observation eq. intercept - maybe_tv("T", "k,k"), # transition - maybe_tv("Z", "o,k"), # design - maybe_tv("R", "k,r"), # selection - maybe_tv("H", "o,o"), # observation cov - maybe_tv("Q", "r,r"), # process cov - ] - ) - - output_signature = ",".join( - [ - "(t,k)", # filtered states - "(t,k)", # predicted states - "(t,o)", # forecasts - "(t,k,k)", # filtered covs - "(t,k,k)", # predicted covs - "(t,o,o)", # forecast covs - "(t)", # loglikelihoods - ] - ) + inputs = [ + (None, "t,o"), # data + (None, "k"), # x0 + (None, "k,k"), # P0 + ("c", "k"), + ("d", "o"), + ("T", "k,k"), + ("Z", "o,k"), + ("R", "k,r"), + ("H", "o,o"), + ("Q", "r,r"), + ] - signature = f"{input_signature}->{output_signature}" + outputs = [ + "t,k", + "t,k", + "t,o", + "t,k,k", + "t,k,k", + "t,o,o", + "t", + ] - return pt.vectorize( + return self._vectorize( vectorize_filter, - signature=signature, + inputs, + outputs, + time_varying_names, ) def make_vectorized_smoother(self, cov_jitter, time_varying_names): - def maybe_tv(name, base_sig): - """ - Add leading time dimension if matrix is time varying. - """ - if name in time_varying_names: - return f"(t,{base_sig})" - return f"({base_sig})" - def vectorize_smoother(T, R, Q, filtered_states, filtered_covariances): return self.kalman_smoother.build_graph( T, @@ -1107,28 +1119,24 @@ def vectorize_smoother(T, R, Q, filtered_states, filtered_covariances): time_varying_names, ) - input_signature = ",".join( - [ - maybe_tv("T", "k,k"), # transition - maybe_tv("R", "k,r"), # selection - maybe_tv("Q", "r,r"), # process cov - "(t, k)", # filtered_states - "(t, k, k)", # filtered_covariances - ] - ) - - output_signature = ",".join( - [ - "(t,k)", # smoothed states - "(t,k,k)", # smoothed covs - ] - ) + inputs = [ + ("T", "k,k"), + ("R", "k,r"), + ("Q", "r,r"), + (None, "t,k"), + (None, "t,k,k"), + ] - signature = f"{input_signature}->{output_signature}" + outputs = [ + "t,k", + "t,k,k", + ] - return pt.vectorize( + return self._vectorize( vectorize_smoother, - signature=signature, + inputs, + outputs, + time_varying_names, ) def build_statespace_graph( @@ -1355,8 +1363,8 @@ def _build_dummy_graph(self) -> None: def infer_variable_shape(name): shape = self._name_to_variable[name].type.shape - if self.batch_size: - shape = (self.batch_size, *shape) + shape = (*self.batch_size, *shape) + if not any(dim is None for dim in shape): return shape @@ -2616,6 +2624,7 @@ def _build_forecast_model( temp_coords = self._fit_coords.copy() t0_idx = np.flatnonzero(time_index == t0)[0] + idx = (slice(None), t0_idx) if self.batch_size else (t0_idx,) temp_coords["data_time"] = time_index temp_coords[TIME_DIM] = forecast_index @@ -2665,24 +2674,12 @@ def _build_forecast_model( mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True) - if self.batch_size: - x0 = pm.Deterministic( - "x0_slice", - mu_frozen[:, t0_idx, :], - dims=mu_dims if mu_dims is not None else None, - ) - P0 = pm.Deterministic( - "P0_slice", - cov_frozen[:, t0_idx, :, :], - dims=cov_dims if cov_dims is not None else None, - ) - else: - x0 = pm.Deterministic( - "x0_slice", mu_frozen[t0_idx], dims=mu_dims if mu_dims is not None else None - ) - P0 = pm.Deterministic( - "P0_slice", cov_frozen[t0_idx], dims=cov_dims if cov_dims is not None else None - ) + x0 = pm.Deterministic( + "x0_slice", mu_frozen[idx], dims=mu_dims if mu_dims is not None else None + ) + P0 = pm.Deterministic( + "P0_slice", cov_frozen[idx], dims=cov_dims if cov_dims is not None else None + ) # Get fresh matrices with n_timesteps placeholder still intact. # Build for the full timeline (training + forecast) so that time-varying matrices @@ -3037,7 +3034,7 @@ def impulse_response_function( if self.batch_size: x0 = pm.Deterministic( "x0_new", - pt.zeros((self.batch_size, self.k_states)), + pt.zeros((*self.batch_size, self.k_states)), dims=[BATCH_DIM, ALL_STATE_DIM], ) else: @@ -3053,10 +3050,7 @@ def impulse_response_function( Q = pt.linalg.cholesky(Q) / pt.diag(Q) if shock_trajectory is None: - if self.batch_size: - shock_trajectory = pt.zeros((self.batch_size, n_steps, self.k_posdef)) - else: - shock_trajectory = pt.zeros((n_steps, self.k_posdef)) + shock_trajectory = pt.zeros((*self.batch_size, n_steps, self.k_posdef)) if Q is not None: if self.batch_size: init_shock = pm.MvNormal( @@ -3096,9 +3090,6 @@ def impulse_response_function( if self.batch_size: shock_trajectory = shock_trajectory.swapaxes(0, 1) - def bmv(A, x): - return pt.matmul(A, x[..., None])[..., 0] - def irf_step(*args): if time_varying_T: shock, T, x, c, R = args diff --git a/pymc_extras/statespace/filters/distributions.py b/pymc_extras/statespace/filters/distributions.py index d291033c0..7c549647c 100644 --- a/pymc_extras/statespace/filters/distributions.py +++ b/pymc_extras/statespace/filters/distributions.py @@ -10,6 +10,8 @@ from pytensor.graph.basic import Node from pytensor.tensor.random import multivariate_normal +from pymc_extras.statespace.utils.batch_tools import bmv + floatX = pytensor.config.floatX COV_ZERO_TOL = 0 @@ -157,9 +159,6 @@ def rv_op( append_x0=True, method="svd", ): - def bmv(A, x): - return pt.matmul(A, x[..., None])[..., 0] - if sequence_names is None: sequence_names = [] diff --git a/pymc_extras/statespace/utils/batch_tools.py b/pymc_extras/statespace/utils/batch_tools.py new file mode 100644 index 000000000..49b8a70d9 --- /dev/null +++ b/pymc_extras/statespace/utils/batch_tools.py @@ -0,0 +1,5 @@ +import pytensor.tensor as pt + + +def bmv(A, x): + return pt.matmul(A, x[..., None])[..., 0] From 7b0c1791b357d9f69494e7305e75971ca72f72fd Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Thu, 28 May 2026 18:42:03 -0600 Subject: [PATCH 08/10] always call vectorize_graph, build_statespace is entrypoint for batch_sizes, added a utility to infer the batch dimension --- .../temporary_scratchpad_ssm_batch_dims.ipynb | 5380 ++++++++++++++++- pymc_extras/statespace/core/statespace.py | 76 +- 2 files changed, 5410 insertions(+), 46 deletions(-) diff --git a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb index 94eff8375..0e56ea56b 100644 --- a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb +++ b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "d95e629f", "metadata": {}, "outputs": [], @@ -45,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "34b32aa3", "metadata": {}, "outputs": [], @@ -55,13 +55,13 @@ " k_states = 3 # size of the state vector x\n", " k_posdef = 1 # number of shocks (size of the state covariance matrix Q)\n", " k_endog = 1 # number of observed states\n", - " batch_size = 2\n", + " # batch_size = 2\n", "\n", " super().__init__(\n", " k_endog=k_endog,\n", " k_states=k_states,\n", " k_posdef=k_posdef,\n", - " batch_size=batch_size,\n", + " # batch_size=batch_size,\n", " mode=mode,\n", " )\n", "\n", @@ -123,17 +123,51 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "bfae06f5", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
                          Model Requirements                           \n",
+       "                                                                       \n",
+       "  Variable    Shape    Constraints                         Dimensions  \n",
+       " ───────────────────────────────────────────────────────────────────── \n",
+       "  x0          (3,)                                         ('state',)  \n",
+       "  P0          (3, 3)                           ('state', 'state_aux')  \n",
+       "  ar_params   (3,)     Stationary, please :)             ('ar_lags',)  \n",
+       "  sigma_x     (1,)                                         ('shock',)  \n",
+       "                                                                       \n",
+       " These parameters should be assigned priors inside a PyMC model block  \n",
+       "           before calling the build_statespace_graph method.           \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[3m Model Requirements \u001b[0m\n", + " \n", + " \u001b[1m \u001b[0m\u001b[1mVariable \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mShape \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mConstraints \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1m Dimensions\u001b[0m\u001b[1m \u001b[0m \n", + " ───────────────────────────────────────────────────────────────────── \n", + " x0 \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m,\u001b[1m)\u001b[0m \n", + " P0 \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m, \u001b[32m'state_aux'\u001b[0m\u001b[1m)\u001b[0m \n", + " ar_params \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m,\u001b[1m)\u001b[0m Stationary, please :\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'ar_lags'\u001b[0m,\u001b[1m)\u001b[0m \n", + " sigma_x \u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'shock'\u001b[0m,\u001b[1m)\u001b[0m \n", + " \n", + "\u001b[2;3m These parameters should be assigned priors inside a PyMC model block \u001b[0m\n", + "\u001b[2;3m before calling the build_statespace_graph method. \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "ar3 = AutoRegressiveThree(mode=\"NUMBA\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "27119d27", "metadata": {}, "outputs": [], @@ -143,7 +177,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "d9163bda", "metadata": {}, "outputs": [], @@ -153,10 +187,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "217b812d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:78: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n", + " warnings.warn(NO_TIME_INDEX_WARNING)\n" + ] + } + ], "source": [ "# Not vectorized\n", "with pm.Model(coords=ar3.coords) as pymc_mod:\n", @@ -171,10 +214,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "84ac6e46", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:78: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n", + " warnings.warn(NO_TIME_INDEX_WARNING)\n" + ] + } + ], "source": [ "# Vectorized\n", "with pm.Model(coords=ar3.coords | {\"batch\": [\"batch_1\", \"batch_2\"]}) as pymc_mod:\n", @@ -192,9 +244,173 @@ { "cell_type": "code", "execution_count": null, - "id": "ed208bcd", + "id": "55b1b63c", + "metadata": {}, + "outputs": [], + "source": [ + "matrices = ar3.unpack_statespace()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6bc22797", + "metadata": {}, + "outputs": [], + "source": [ + "matrices_2 = ar3._unpack_statespace_with_placeholders()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9dbac19", + "metadata": {}, + "outputs": [], + "source": [ + "[(m.name, m.type.shape) for m in matrices_2]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3aa195b", + "metadata": {}, + "outputs": [], + "source": [ + "[(m.name, m.type.shape) for m in matrices]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e4e79de", + "metadata": {}, + "outputs": [], + "source": [ + "type(matrices[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d302f051", "metadata": {}, "outputs": [], + "source": [ + "def infer_batch_dimensions(\n", + " core_matrices: list[pt.TensorVariable],\n", + " subbed_matrices: list[pt.TensorVariable],\n", + ") -> tuple[int | None, ...]:\n", + " inferred_batch_dims = ()\n", + "\n", + " for core_matrix, sub_matrix in zip(core_matrices, subbed_matrices):\n", + " core_shape = core_matrix.type.shape\n", + " sub_shape = sub_matrix.type.shape\n", + "\n", + " if len(sub_shape) < len(core_shape):\n", + " raise ValueError(\n", + " f\"Subbed matrix has fewer dims than core matrix: \" f\"{sub_shape} vs {core_shape}\"\n", + " )\n", + "\n", + " # Verify trailing/core dimensions match\n", + " trailing_shape = sub_shape[-len(core_shape) :]\n", + "\n", + " for core_dim, sub_dim in zip(core_shape, trailing_shape):\n", + " if core_dim is not None and sub_dim is not None and core_dim != sub_dim:\n", + " raise ValueError(f\"Core dimension mismatch: \" f\"{core_shape} vs {sub_shape}\")\n", + "\n", + " batch_dims = sub_shape[: -len(core_shape)]\n", + "\n", + " # Skip matrices with no batch dimensions\n", + " if len(batch_dims) == 0:\n", + " continue\n", + "\n", + " # First batched tensor establishes the batch shape\n", + " if len(inferred_batch_dims) == 0:\n", + " inferred_batch_dims = batch_dims\n", + " continue\n", + "\n", + " # Validate consistency\n", + " if len(batch_dims) != len(inferred_batch_dims):\n", + " raise ValueError(f\"Inconsistent batch rank: \" f\"{batch_dims} vs {inferred_batch_dims}\")\n", + "\n", + " merged_dims = []\n", + "\n", + " for inferred_dim, new_dim in zip(inferred_batch_dims, batch_dims):\n", + " if inferred_dim is None:\n", + " merged_dims.append(new_dim)\n", + " elif new_dim is None:\n", + " merged_dims.append(inferred_dim)\n", + " elif inferred_dim == new_dim:\n", + " merged_dims.append(inferred_dim)\n", + " else:\n", + " raise ValueError(\n", + " f\"Inconsistent batch dimensions: \" f\"{batch_dims} vs {inferred_batch_dims}\"\n", + " )\n", + "\n", + " inferred_batch_dims = tuple(merged_dims)\n", + "\n", + " return inferred_batch_dims" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "257c1a2b", + "metadata": {}, + "outputs": [], + "source": [ + "infer_batch_dimensions(core_matrices=matrices_2, subbed_matrices=matrices)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "ed208bcd", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Initializing NUTS using jitter+adapt_diag...\n", + "Multiprocess sampling (4 chains in 4 jobs)\n", + "NUTS: [ar_params, sigma_x]\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f844cca998fe40b59f9a78d011c634c9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Sampling 4 chains for 200 tune and 200 draw iterations (800 + 800 draws total) took 6 seconds.\n"
+     ]
+    }
+   ],
    "source": [
     "with pymc_mod:\n",
     "    idata = pm.sample(tune=200, draws=200, compile_kwargs={\"mode\": \"NUMBA\"})"
@@ -202,60 +418,2831 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 8,
    "id": "c172bc3a",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Sampling: [initial_shock]\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "68b3dc4111e5462596fbffec03183126",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataTree 'posterior_predictive'>\n",
+       "Group: /posterior_predictive\n",
+       "    Dimensions:  (chain: 4, draw: 200, batch: 2, time: 40, state: 3)\n",
+       "    Coordinates:\n",
+       "      * chain    (chain) int64 32B 0 1 2 3\n",
+       "      * draw     (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n",
+       "      * batch    (batch) <U7 56B 'batch_1' 'batch_2'\n",
+       "      * time     (time) int64 320B 0 1 2 3 4 5 6 7 8 ... 31 32 33 34 35 36 37 38 39\n",
+       "      * state    (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
+       "    Data variables:\n",
+       "        irf      (chain, draw, batch, time, state) float64 2MB 2.659 ... 4.025e-13\n",
+       "    Attributes:\n",
+       "        created_at:                 2026-05-29T00:26:42.616309+00:00\n",
+       "        creation_library:           ArviZ\n",
+       "        creation_library_version:   1.1.0\n",
+       "        creation_library_language:  Python\n",
+       "        inference_library:          pymc\n",
+       "        inference_library_version:  6.0.0\n",
+       "        sample_dims:                ['chain', 'draw']
" + ], + "text/plain": [ + "\n", + "Group: /posterior_predictive\n", + " Dimensions: (chain: 4, draw: 200, batch: 2, time: 40, state: 3)\n", + " Coordinates:\n", + " * chain (chain) int64 32B 0 1 2 3\n", + " * draw (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n", + " * batch (batch) \n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataTree>\n",
+       "Group: /\n",
+       "├── Group: /posterior_predictive\n",
+       "│       Dimensions:    (chain: 4, draw: 200, batch: 2, state: 3, state_aux: 3)\n",
+       "│       Coordinates:\n",
+       "│         * chain      (chain) int64 32B 0 1 2 3\n",
+       "│         * draw       (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n",
+       "│         * batch      (batch) <U7 56B 'batch_1' 'batch_2'\n",
+       "│         * state      (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
+       "│         * state_aux  (state_aux) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
+       "│       Data variables:\n",
+       "│           T          (chain, draw, batch, state, state_aux) float64 115kB 0.06838 ....\n",
+       "│       Attributes:\n",
+       "│           created_at:                 2026-05-29T00:26:44.280981+00:00\n",
+       "│           creation_library:           ArviZ\n",
+       "│           creation_library_version:   1.1.0\n",
+       "│           creation_library_language:  Python\n",
+       "│           inference_library:          pymc\n",
+       "│           inference_library_version:  6.0.0\n",
+       "│           sample_dims:                ['chain', 'draw']\n",
+       "└── Group: /observed_data\n",
+       "        Attributes:\n",
+       "            created_at:                 2026-05-29T00:26:44.282187+00:00\n",
+       "            creation_library:           ArviZ\n",
+       "            creation_library_version:   1.1.0\n",
+       "            creation_library_language:  Python\n",
+       "            inference_library:          pymc\n",
+       "            inference_library_version:  6.0.0\n",
+       "            sample_dims:                []
" + ], + "text/plain": [ + "\n", + "Group: /\n", + "├── Group: /posterior_predictive\n", + "│ Dimensions: (chain: 4, draw: 200, batch: 2, state: 3, state_aux: 3)\n", + "│ Coordinates:\n", + "│ * chain (chain) int64 32B 0 1 2 3\n", + "│ * draw (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n", + "│ * batch (batch) \n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataTree>\n",
+       "Group: /\n",
+       "├── Group: /posterior_predictive\n",
+       "│       Dimensions:               (chain: 4, draw: 200, batch: 2, time: 100, state: 3,\n",
+       "│                                  state_aux: 3)\n",
+       "│       Coordinates:\n",
+       "│         * chain                 (chain) int64 32B 0 1 2 3\n",
+       "│         * draw                  (draw) int64 2kB 0 1 2 3 4 5 ... 195 196 197 198 199\n",
+       "│         * batch                 (batch) int64 16B 0 1\n",
+       "│         * time                  (time) int64 800B 0 1 2 3 4 5 6 ... 94 95 96 97 98 99\n",
+       "│         * state                 (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
+       "│         * state_aux             (state_aux) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
+       "│       Data variables:\n",
+       "│           filtered_covariances  (chain, draw, batch, time, state, state_aux) float64 12MB ...\n",
+       "│       Attributes:\n",
+       "│           created_at:                 2026-05-29T00:26:48.352615+00:00\n",
+       "│           creation_library:           ArviZ\n",
+       "│           creation_library_version:   1.1.0\n",
+       "│           creation_library_language:  Python\n",
+       "│           inference_library:          pymc\n",
+       "│           inference_library_version:  6.0.0\n",
+       "│           sample_dims:                ['chain', 'draw']\n",
+       "├── Group: /observed_data\n",
+       "│       Attributes:\n",
+       "│           created_at:                 2026-05-29T00:26:48.353653+00:00\n",
+       "│           creation_library:           ArviZ\n",
+       "│           creation_library_version:   1.1.0\n",
+       "│           creation_library_language:  Python\n",
+       "│           inference_library:          pymc\n",
+       "│           inference_library_version:  6.0.0\n",
+       "│           sample_dims:                []\n",
+       "└── Group: /constant_data\n",
+       "        Dimensions:         (batch: 2, time: 100, observed_state: 1)\n",
+       "        Coordinates:\n",
+       "          * batch           (batch) int64 16B 0 1\n",
+       "          * time            (time) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99\n",
+       "          * observed_state  (observed_state) <U3 12B 'ts1'\n",
+       "        Data variables:\n",
+       "            data            (batch, time, observed_state) float64 2kB -0.9498 ... 0.3309\n",
+       "        Attributes:\n",
+       "            created_at:                 2026-05-29T00:26:48.354952+00:00\n",
+       "            creation_library:           ArviZ\n",
+       "            creation_library_version:   1.1.0\n",
+       "            creation_library_language:  Python\n",
+       "            inference_library:          pymc\n",
+       "            inference_library_version:  6.0.0\n",
+       "            sample_dims:                []
" + ], + "text/plain": [ + "\n", + "Group: /\n", + "├── Group: /posterior_predictive\n", + "│ Dimensions: (chain: 4, draw: 200, batch: 2, time: 100, state: 3,\n", + "│ state_aux: 3)\n", + "│ Coordinates:\n", + "│ * chain (chain) int64 32B 0 1 2 3\n", + "│ * draw (draw) int64 2kB 0 1 2 3 4 5 ... 195 196 197 198 199\n", + "│ * batch (batch) int64 16B 0 1\n", + "│ * time (time) int64 800B 0 1 2 3 4 5 6 ... 94 95 96 97 98 99\n", + "│ * state (state) \n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataTree 'posterior_predictive'>\n",
+       "Group: /posterior_predictive\n",
+       "    Dimensions:            (chain: 4, draw: 200, time: 10, batch: 2, state: 3,\n",
+       "                            observed_state: 1)\n",
+       "    Coordinates:\n",
+       "      * chain              (chain) int64 32B 0 1 2 3\n",
+       "      * draw               (draw) int64 2kB 0 1 2 3 4 5 ... 194 195 196 197 198 199\n",
+       "      * time               (time) int64 80B 100 101 102 103 104 105 106 107 108 109\n",
+       "      * batch              (batch) <U7 56B 'batch_1' 'batch_2'\n",
+       "      * state              (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
+       "      * observed_state     (observed_state) <U3 12B 'ts1'\n",
+       "    Data variables:\n",
+       "        forecast_latent    (chain, draw, time, batch, state) float64 384kB -1.179...\n",
+       "        forecast_observed  (chain, draw, time, batch, observed_state) float64 128kB ...\n",
+       "    Attributes:\n",
+       "        created_at:                 2026-05-29T00:26:56.955319+00:00\n",
+       "        creation_library:           ArviZ\n",
+       "        creation_library_version:   1.1.0\n",
+       "        creation_library_language:  Python\n",
+       "        inference_library:          pymc\n",
+       "        inference_library_version:  6.0.0\n",
+       "        sample_dims:                ['chain', 'draw']
" + ], + "text/plain": [ + "\n", + "Group: /posterior_predictive\n", + " Dimensions: (chain: 4, draw: 200, time: 10, batch: 2, state: 3,\n", + " observed_state: 1)\n", + " Coordinates:\n", + " * chain (chain) int64 32B 0 1 2 3\n", + " * draw (draw) int64 2kB 0 1 2 3 4 5 ... 194 195 196 197 198 199\n", + " * time (time) int64 80B 100 101 102 103 104 105 106 107 108 109\n", + " * batch (batch) \n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "post = ar3.sample_conditional_posterior(idata, mvn_method=\"cholesky\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "5c3800a9", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling: [ar_params, obs, sigma_x]\n" + ] + } + ], "source": [ "with pymc_mod:\n", " prior = pm.sample_prior_predictive(compile_kwargs={\"mode\": \"NUMBA\"})" @@ -263,40 +3250,2367 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "id": "5857a994", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:78: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n", + " warnings.warn(NO_TIME_INDEX_WARNING)\n", + "Sampling: [filtered_prior, filtered_prior_observed, predicted_prior, predicted_prior_observed, smoothed_prior, smoothed_prior_observed]\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "69a5f40528374795b2144353ba1f7dfe", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataTree 'posterior_predictive'>\n",
+       "Group: /posterior_predictive\n",
+       "    Dimensions:                   (chain: 1, draw: 500, batch: 2, time: 100,\n",
+       "                                   state: 3, observed_state: 1)\n",
+       "    Coordinates:\n",
+       "      * chain                     (chain) int64 8B 0\n",
+       "      * draw                      (draw) int64 4kB 0 1 2 3 4 ... 495 496 497 498 499\n",
+       "      * batch                     (batch) <U7 56B 'batch_1' 'batch_2'\n",
+       "      * time                      (time) int64 800B 0 1 2 3 4 5 ... 95 96 97 98 99\n",
+       "      * state                     (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
+       "      * observed_state            (observed_state) <U3 12B 'ts1'\n",
+       "    Data variables:\n",
+       "        filtered_prior            (chain, draw, batch, time, state) float64 2MB -...\n",
+       "        filtered_prior_observed   (chain, draw, batch, time, observed_state) float64 800kB ...\n",
+       "        predicted_prior           (chain, draw, batch, time, state) float64 2MB -...\n",
+       "        predicted_prior_observed  (chain, draw, batch, time, observed_state) float64 800kB ...\n",
+       "        smoothed_prior            (chain, draw, batch, time, state) float64 2MB -...\n",
+       "        smoothed_prior_observed   (chain, draw, batch, time, observed_state) float64 800kB ...\n",
+       "    Attributes:\n",
+       "        created_at:                 2026-05-29T00:27:23.480095+00:00\n",
+       "        creation_library:           ArviZ\n",
+       "        creation_library_version:   1.1.0\n",
+       "        creation_library_language:  Python\n",
+       "        inference_library:          pymc\n",
+       "        inference_library_version:  6.0.0\n",
+       "        sample_dims:                ['chain', 'draw']
" + ], + "text/plain": [ + "\n", + "Group: /posterior_predictive\n", + " Dimensions: (chain: 1, draw: 500, batch: 2, time: 100,\n", + " state: 3, observed_state: 1)\n", + " Coordinates:\n", + " * chain (chain) int64 8B 0\n", + " * draw (draw) int64 4kB 0 1 2 3 4 ... 495 496 497 498 499\n", + " * batch (batch) \n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataTree 'posterior_predictive'>\n",
+       "Group: /posterior_predictive\n",
+       "    Dimensions:         (chain: 1, draw: 500, time: 100, batch: 2, state: 3,\n",
+       "                         observed_state: 1)\n",
+       "    Coordinates:\n",
+       "      * chain           (chain) int64 8B 0\n",
+       "      * draw            (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n",
+       "      * time            (time) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99\n",
+       "      * batch           (batch) <U7 56B 'batch_1' 'batch_2'\n",
+       "      * state           (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
+       "      * observed_state  (observed_state) <U3 12B 'ts1'\n",
+       "    Data variables:\n",
+       "        prior_latent    (chain, draw, time, batch, state) float64 2MB -0.3617 ......\n",
+       "        prior_observed  (chain, draw, time, batch, observed_state) float64 800kB ...\n",
+       "    Attributes:\n",
+       "        created_at:                 2026-05-29T00:27:25.728592+00:00\n",
+       "        creation_library:           ArviZ\n",
+       "        creation_library_version:   1.1.0\n",
+       "        creation_library_language:  Python\n",
+       "        inference_library:          pymc\n",
+       "        inference_library_version:  6.0.0\n",
+       "        sample_dims:                ['chain', 'draw']
" + ], + "text/plain": [ + "\n", + "Group: /posterior_predictive\n", + " Dimensions: (chain: 1, draw: 500, time: 100, batch: 2, state: 3,\n", + " observed_state: 1)\n", + " Coordinates:\n", + " * chain (chain) int64 8B 0\n", + " * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n", + " * time (time) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99\n", + " * batch (batch) \n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "unpost = ar3.sample_unconditional_posterior(idata, mvn_method=\"cholesky\")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "id": "b112c551", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataTree 'posterior_predictive'>\n",
+       "Group: /posterior_predictive\n",
+       "    Dimensions:             (chain: 4, draw: 200, time: 100, batch: 2, state: 3,\n",
+       "                             observed_state: 1)\n",
+       "    Coordinates:\n",
+       "      * chain               (chain) int64 32B 0 1 2 3\n",
+       "      * draw                (draw) int64 2kB 0 1 2 3 4 5 ... 194 195 196 197 198 199\n",
+       "      * time                (time) int64 800B 0 1 2 3 4 5 6 ... 93 94 95 96 97 98 99\n",
+       "      * batch               (batch) <U7 56B 'batch_1' 'batch_2'\n",
+       "      * state               (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
+       "      * observed_state      (observed_state) <U3 12B 'ts1'\n",
+       "    Data variables:\n",
+       "        posterior_latent    (chain, draw, time, batch, state) float64 4MB -1.517 ...\n",
+       "        posterior_observed  (chain, draw, time, batch, observed_state) float64 1MB ...\n",
+       "    Attributes:\n",
+       "        created_at:                 2026-05-29T00:27:27.765195+00:00\n",
+       "        creation_library:           ArviZ\n",
+       "        creation_library_version:   1.1.0\n",
+       "        creation_library_language:  Python\n",
+       "        inference_library:          pymc\n",
+       "        inference_library_version:  6.0.0\n",
+       "        sample_dims:                ['chain', 'draw']
" + ], + "text/plain": [ + "\n", + "Group: /posterior_predictive\n", + " Dimensions: (chain: 4, draw: 200, time: 100, batch: 2, state: 3,\n", + " observed_state: 1)\n", + " Coordinates:\n", + " * chain (chain) int64 32B 0 1 2 3\n", + " * draw (draw) int64 2kB 0 1 2 3 4 5 ... 194 195 196 197 198 199\n", + " * time (time) int64 800B 0 1 2 3 4 5 6 ... 93 94 95 96 97 98 99\n", + " * batch (batch) tuple[int | None, ...]: + inferred_batch_dims = () + + for core_matrix, sub_matrix in zip(core_matrices, subbed_matrices): + core_shape = core_matrix.type.shape + sub_shape = sub_matrix.type.shape + + if len(sub_shape) < len(core_shape): + raise ValueError( + f"Subbed matrix has fewer dims than core matrix: " f"{sub_shape} vs {core_shape}" + ) + + # Verify trailing/core dimensions match + trailing_shape = sub_shape[-len(core_shape) :] + + for core_dim, sub_dim in zip(core_shape, trailing_shape): + if core_dim is not None and sub_dim is not None and core_dim != sub_dim: + raise ValueError(f"Core dimension mismatch: " f"{core_shape} vs {sub_shape}") + + batch_dims = sub_shape[: -len(core_shape)] + + # Skip matrices with no batch dimensions + if len(batch_dims) == 0: + continue + + # First batched tensor establishes the batch shape + if len(inferred_batch_dims) == 0: + inferred_batch_dims = batch_dims + continue + + # Validate consistency + if len(batch_dims) != len(inferred_batch_dims): + raise ValueError(f"Inconsistent batch rank: " f"{batch_dims} vs {inferred_batch_dims}") + + merged_dims = [] + + for inferred_dim, new_dim in zip(inferred_batch_dims, batch_dims): + if inferred_dim is None: + merged_dims.append(new_dim) + elif new_dim is None: + merged_dims.append(inferred_dim) + elif inferred_dim == new_dim: + merged_dims.append(inferred_dim) + else: + raise ValueError( + f"Inconsistent batch dimensions: " f"{batch_dims} vs {inferred_batch_dims}" + ) + + inferred_batch_dims = tuple(merged_dims) + + return inferred_batch_dims + + class PyMCStateSpace: r""" Base class for Linear Gaussian Statespace models in PyMC. @@ -260,7 +316,6 @@ def __init__( k_endog: int, k_states: int, k_posdef: int, - batch_size: int | None = None, filter_type: str = "standard", verbose: bool = True, measurement_error: bool = False, @@ -282,8 +337,6 @@ def __init__( self.measurement_error = measurement_error self.mode = mode - self.batch_size = (batch_size,) if batch_size is not None else () - self._populate_properties() # Placeholder for time-varying matrices that depend on data length @@ -878,16 +931,8 @@ def _insert_random_variables(self): matrices = list(self._unpack_statespace_with_placeholders()) - if self.batch_size: - replacement_dict = { - var: pymc_model[name] for name, var in self._name_to_variable.items() - } - self.subbed_ssm = vectorize_graph(matrices, replace=replacement_dict) - else: - replacement_dict = { - var: pymc_model[name] for name, var in self._name_to_variable.items() - } - self.subbed_ssm = graph_replace(matrices, replace=replacement_dict) + replacement_dict = {var: pymc_model[name] for name, var in self._name_to_variable.items()} + self.subbed_ssm = vectorize_graph(matrices, replace=replacement_dict) def _insert_data_variables(self): """ @@ -1224,6 +1269,11 @@ def build_statespace_graph( self._save_exogenous_data_info() self._insert_data_variables() + self.batch_size = _infer_batch_dimensions( + core_matrices=self._unpack_statespace_with_placeholders(), + subbed_matrices=self.unpack_statespace(), + ) + obs_coords = pm_mod.coords.get(OBS_STATE_DIM, None) self._fit_data = data From 3410b2197291572d93b947fd5f483fdd6120fc98 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Sun, 7 Jun 2026 16:43:48 -0600 Subject: [PATCH 09/10] 1. fixed a bug with sample_filter_outputs() when vector_graph doesn't vectorize certain matrices 2. Updated observation covariance computation in sample_conditional to support batched dimensions 3. added tests for batched dimension SSMs --- .../temporary_scratchpad_ssm_batch_dims.ipynb | 5412 +---------------- pymc_extras/statespace/core/statespace.py | 6 +- tests/statespace/core/test_statespace.py | 284 + 3 files changed, 356 insertions(+), 5346 deletions(-) diff --git a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb index 0e56ea56b..4d27bcb6d 100644 --- a/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb +++ b/notebooks/temporary_scratchpad_ssm_batch_dims.ipynb @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "d95e629f", "metadata": {}, "outputs": [], @@ -45,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "34b32aa3", "metadata": {}, "outputs": [], @@ -55,13 +55,11 @@ " k_states = 3 # size of the state vector x\n", " k_posdef = 1 # number of shocks (size of the state covariance matrix Q)\n", " k_endog = 1 # number of observed states\n", - " # batch_size = 2\n", "\n", " super().__init__(\n", " k_endog=k_endog,\n", " k_states=k_states,\n", " k_posdef=k_posdef,\n", - " # batch_size=batch_size,\n", " mode=mode,\n", " )\n", "\n", @@ -123,83 +121,50 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "bfae06f5", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
                          Model Requirements                           \n",
-       "                                                                       \n",
-       "  Variable    Shape    Constraints                         Dimensions  \n",
-       " ───────────────────────────────────────────────────────────────────── \n",
-       "  x0          (3,)                                         ('state',)  \n",
-       "  P0          (3, 3)                           ('state', 'state_aux')  \n",
-       "  ar_params   (3,)     Stationary, please :)             ('ar_lags',)  \n",
-       "  sigma_x     (1,)                                         ('shock',)  \n",
-       "                                                                       \n",
-       " These parameters should be assigned priors inside a PyMC model block  \n",
-       "           before calling the build_statespace_graph method.           \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[3m Model Requirements \u001b[0m\n", - " \n", - " \u001b[1m \u001b[0m\u001b[1mVariable \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mShape \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mConstraints \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1m Dimensions\u001b[0m\u001b[1m \u001b[0m \n", - " ───────────────────────────────────────────────────────────────────── \n", - " x0 \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m,\u001b[1m)\u001b[0m \n", - " P0 \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m, \u001b[32m'state_aux'\u001b[0m\u001b[1m)\u001b[0m \n", - " ar_params \u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m,\u001b[1m)\u001b[0m Stationary, please :\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'ar_lags'\u001b[0m,\u001b[1m)\u001b[0m \n", - " sigma_x \u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m,\u001b[1m)\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'shock'\u001b[0m,\u001b[1m)\u001b[0m \n", - " \n", - "\u001b[2;3m These parameters should be assigned priors inside a PyMC model block \u001b[0m\n", - "\u001b[2;3m before calling the build_statespace_graph method. \u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "ar3 = AutoRegressiveThree(mode=\"NUMBA\")" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "id": "27119d27", "metadata": {}, "outputs": [], "source": [ - "data = np.random.normal(0, 1, size=(100, 2))" + "data = np.random.normal(0, 1, size=(100, 10))" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "d9163bda", "metadata": {}, "outputs": [], "source": [ - "batched_data = data.reshape(2, 100, 1)" + "batched_data = data.reshape(10, 100, 1)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, + "id": "8ed64433", + "metadata": {}, + "outputs": [], + "source": [ + "batched_data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, "id": "217b812d", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:78: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n", - " warnings.warn(NO_TIME_INDEX_WARNING)\n" - ] - } - ], + "outputs": [], "source": [ "# Not vectorized\n", "with pm.Model(coords=ar3.coords) as pymc_mod:\n", @@ -214,25 +179,32 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "84ac6e46", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:78: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n", - " warnings.warn(NO_TIME_INDEX_WARNING)\n" - ] - } - ], + "outputs": [], "source": [ "# Vectorized\n", - "with pm.Model(coords=ar3.coords | {\"batch\": [\"batch_1\", \"batch_2\"]}) as pymc_mod:\n", - " x0 = pm.Deterministic(\"x0\", pt.zeros((2, 3)), dims=(\"batch\", \"state\"))\n", + "with pm.Model(\n", + " coords=ar3.coords\n", + " | {\n", + " \"batch\": [\n", + " \"batch_1\",\n", + " \"batch_2\",\n", + " \"batch_3\",\n", + " \"batch_4\",\n", + " \"batch_5\",\n", + " \"batch_6\",\n", + " \"batch_7\",\n", + " \"batch_8\",\n", + " \"batch_9\",\n", + " \"batch_10\",\n", + " ]\n", + " }\n", + ") as pymc_mod:\n", + " x0 = pm.Deterministic(\"x0\", pt.zeros((10, 3)), dims=(\"batch\", \"state\"))\n", " P0 = pm.Deterministic(\n", - " \"P0\", pt.tile(pt.eye(3) * 1, (2, 1, 1)), dims=(\"batch\", \"state\", \"state_aux\")\n", + " \"P0\", pt.tile(pt.eye(3) * 1, (10, 1, 1)), dims=(\"batch\", \"state\", \"state_aux\")\n", " )\n", " ar_params = pm.Normal(\"ar_params\", dims=(\"batch\", \"state\"))\n", "\n", @@ -244,3005 +216,90 @@ { "cell_type": "code", "execution_count": null, - "id": "55b1b63c", + "id": "ed208bcd", "metadata": {}, "outputs": [], "source": [ - "matrices = ar3.unpack_statespace()" + "with pymc_mod:\n", + " idata = pm.sample(tune=20, draws=20, compile_kwargs={\"mode\": \"NUMBA\"})" ] }, { "cell_type": "code", "execution_count": null, - "id": "6bc22797", + "id": "c172bc3a", "metadata": {}, "outputs": [], "source": [ - "matrices_2 = ar3._unpack_statespace_with_placeholders()" + "ar3.impulse_response_function(idata)" ] }, { "cell_type": "code", "execution_count": null, - "id": "c9dbac19", + "id": "674ae03e", "metadata": {}, "outputs": [], "source": [ - "[(m.name, m.type.shape) for m in matrices_2]" + "ar3.ssm[\"design\"].eval().shape" ] }, { "cell_type": "code", "execution_count": null, - "id": "e3aa195b", + "id": "0fd9532e", "metadata": {}, "outputs": [], "source": [ - "[(m.name, m.type.shape) for m in matrices]" + "pt.tile(ar3.ssm[\"design\"], (3, 1)).eval().shape" ] }, { "cell_type": "code", "execution_count": null, - "id": "0e4e79de", + "id": "24035cf0", "metadata": {}, "outputs": [], "source": [ - "type(matrices[0])" + "ar3.sample_statespace_matrices(idata, matrix_names=[\"Z\"]).posterior_predictive" ] }, { "cell_type": "code", "execution_count": null, - "id": "d302f051", + "id": "64940705", "metadata": {}, "outputs": [], "source": [ - "def infer_batch_dimensions(\n", - " core_matrices: list[pt.TensorVariable],\n", - " subbed_matrices: list[pt.TensorVariable],\n", - ") -> tuple[int | None, ...]:\n", - " inferred_batch_dims = ()\n", - "\n", - " for core_matrix, sub_matrix in zip(core_matrices, subbed_matrices):\n", - " core_shape = core_matrix.type.shape\n", - " sub_shape = sub_matrix.type.shape\n", - "\n", - " if len(sub_shape) < len(core_shape):\n", - " raise ValueError(\n", - " f\"Subbed matrix has fewer dims than core matrix: \" f\"{sub_shape} vs {core_shape}\"\n", - " )\n", - "\n", - " # Verify trailing/core dimensions match\n", - " trailing_shape = sub_shape[-len(core_shape) :]\n", - "\n", - " for core_dim, sub_dim in zip(core_shape, trailing_shape):\n", - " if core_dim is not None and sub_dim is not None and core_dim != sub_dim:\n", - " raise ValueError(f\"Core dimension mismatch: \" f\"{core_shape} vs {sub_shape}\")\n", - "\n", - " batch_dims = sub_shape[: -len(core_shape)]\n", - "\n", - " # Skip matrices with no batch dimensions\n", - " if len(batch_dims) == 0:\n", - " continue\n", - "\n", - " # First batched tensor establishes the batch shape\n", - " if len(inferred_batch_dims) == 0:\n", - " inferred_batch_dims = batch_dims\n", - " continue\n", - "\n", - " # Validate consistency\n", - " if len(batch_dims) != len(inferred_batch_dims):\n", - " raise ValueError(f\"Inconsistent batch rank: \" f\"{batch_dims} vs {inferred_batch_dims}\")\n", - "\n", - " merged_dims = []\n", - "\n", - " for inferred_dim, new_dim in zip(inferred_batch_dims, batch_dims):\n", - " if inferred_dim is None:\n", - " merged_dims.append(new_dim)\n", - " elif new_dim is None:\n", - " merged_dims.append(inferred_dim)\n", - " elif inferred_dim == new_dim:\n", - " merged_dims.append(inferred_dim)\n", - " else:\n", - " raise ValueError(\n", - " f\"Inconsistent batch dimensions: \" f\"{batch_dims} vs {inferred_batch_dims}\"\n", - " )\n", - "\n", - " inferred_batch_dims = tuple(merged_dims)\n", - "\n", - " return inferred_batch_dims" + "ar3.sample_filter_outputs(idata, filter_output_names=[\"smoothed_covariances\"])" ] }, { "cell_type": "code", "execution_count": null, - "id": "257c1a2b", - "metadata": {}, - "outputs": [], - "source": [ - "infer_batch_dimensions(core_matrices=matrices_2, subbed_matrices=matrices)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "ed208bcd", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Initializing NUTS using jitter+adapt_diag...\n", - "Multiprocess sampling (4 chains in 4 jobs)\n", - "NUTS: [ar_params, sigma_x]\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f844cca998fe40b59f9a78d011c634c9", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Sampling 4 chains for 200 tune and 200 draw iterations (800 + 800 draws total) took 6 seconds.\n"
-     ]
-    }
-   ],
-   "source": [
-    "with pymc_mod:\n",
-    "    idata = pm.sample(tune=200, draws=200, compile_kwargs={\"mode\": \"NUMBA\"})"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 8,
-   "id": "c172bc3a",
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "Sampling: [initial_shock]\n"
-     ]
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "68b3dc4111e5462596fbffec03183126",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Output()"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.DataTree 'posterior_predictive'>\n",
-       "Group: /posterior_predictive\n",
-       "    Dimensions:  (chain: 4, draw: 200, batch: 2, time: 40, state: 3)\n",
-       "    Coordinates:\n",
-       "      * chain    (chain) int64 32B 0 1 2 3\n",
-       "      * draw     (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n",
-       "      * batch    (batch) <U7 56B 'batch_1' 'batch_2'\n",
-       "      * time     (time) int64 320B 0 1 2 3 4 5 6 7 8 ... 31 32 33 34 35 36 37 38 39\n",
-       "      * state    (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
-       "    Data variables:\n",
-       "        irf      (chain, draw, batch, time, state) float64 2MB 2.659 ... 4.025e-13\n",
-       "    Attributes:\n",
-       "        created_at:                 2026-05-29T00:26:42.616309+00:00\n",
-       "        creation_library:           ArviZ\n",
-       "        creation_library_version:   1.1.0\n",
-       "        creation_library_language:  Python\n",
-       "        inference_library:          pymc\n",
-       "        inference_library_version:  6.0.0\n",
-       "        sample_dims:                ['chain', 'draw']
" - ], - "text/plain": [ - "\n", - "Group: /posterior_predictive\n", - " Dimensions: (chain: 4, draw: 200, batch: 2, time: 40, state: 3)\n", - " Coordinates:\n", - " * chain (chain) int64 32B 0 1 2 3\n", - " * draw (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n", - " * batch (batch) \n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.DataTree>\n",
-       "Group: /\n",
-       "├── Group: /posterior_predictive\n",
-       "│       Dimensions:    (chain: 4, draw: 200, batch: 2, state: 3, state_aux: 3)\n",
-       "│       Coordinates:\n",
-       "│         * chain      (chain) int64 32B 0 1 2 3\n",
-       "│         * draw       (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n",
-       "│         * batch      (batch) <U7 56B 'batch_1' 'batch_2'\n",
-       "│         * state      (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
-       "│         * state_aux  (state_aux) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
-       "│       Data variables:\n",
-       "│           T          (chain, draw, batch, state, state_aux) float64 115kB 0.06838 ....\n",
-       "│       Attributes:\n",
-       "│           created_at:                 2026-05-29T00:26:44.280981+00:00\n",
-       "│           creation_library:           ArviZ\n",
-       "│           creation_library_version:   1.1.0\n",
-       "│           creation_library_language:  Python\n",
-       "│           inference_library:          pymc\n",
-       "│           inference_library_version:  6.0.0\n",
-       "│           sample_dims:                ['chain', 'draw']\n",
-       "└── Group: /observed_data\n",
-       "        Attributes:\n",
-       "            created_at:                 2026-05-29T00:26:44.282187+00:00\n",
-       "            creation_library:           ArviZ\n",
-       "            creation_library_version:   1.1.0\n",
-       "            creation_library_language:  Python\n",
-       "            inference_library:          pymc\n",
-       "            inference_library_version:  6.0.0\n",
-       "            sample_dims:                []
" - ], - "text/plain": [ - "\n", - "Group: /\n", - "├── Group: /posterior_predictive\n", - "│ Dimensions: (chain: 4, draw: 200, batch: 2, state: 3, state_aux: 3)\n", - "│ Coordinates:\n", - "│ * chain (chain) int64 32B 0 1 2 3\n", - "│ * draw (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 193 194 195 196 197 198 199\n", - "│ * batch (batch) \n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.DataTree>\n",
-       "Group: /\n",
-       "├── Group: /posterior_predictive\n",
-       "│       Dimensions:               (chain: 4, draw: 200, batch: 2, time: 100, state: 3,\n",
-       "│                                  state_aux: 3)\n",
-       "│       Coordinates:\n",
-       "│         * chain                 (chain) int64 32B 0 1 2 3\n",
-       "│         * draw                  (draw) int64 2kB 0 1 2 3 4 5 ... 195 196 197 198 199\n",
-       "│         * batch                 (batch) int64 16B 0 1\n",
-       "│         * time                  (time) int64 800B 0 1 2 3 4 5 6 ... 94 95 96 97 98 99\n",
-       "│         * state                 (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
-       "│         * state_aux             (state_aux) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
-       "│       Data variables:\n",
-       "│           filtered_covariances  (chain, draw, batch, time, state, state_aux) float64 12MB ...\n",
-       "│       Attributes:\n",
-       "│           created_at:                 2026-05-29T00:26:48.352615+00:00\n",
-       "│           creation_library:           ArviZ\n",
-       "│           creation_library_version:   1.1.0\n",
-       "│           creation_library_language:  Python\n",
-       "│           inference_library:          pymc\n",
-       "│           inference_library_version:  6.0.0\n",
-       "│           sample_dims:                ['chain', 'draw']\n",
-       "├── Group: /observed_data\n",
-       "│       Attributes:\n",
-       "│           created_at:                 2026-05-29T00:26:48.353653+00:00\n",
-       "│           creation_library:           ArviZ\n",
-       "│           creation_library_version:   1.1.0\n",
-       "│           creation_library_language:  Python\n",
-       "│           inference_library:          pymc\n",
-       "│           inference_library_version:  6.0.0\n",
-       "│           sample_dims:                []\n",
-       "└── Group: /constant_data\n",
-       "        Dimensions:         (batch: 2, time: 100, observed_state: 1)\n",
-       "        Coordinates:\n",
-       "          * batch           (batch) int64 16B 0 1\n",
-       "          * time            (time) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99\n",
-       "          * observed_state  (observed_state) <U3 12B 'ts1'\n",
-       "        Data variables:\n",
-       "            data            (batch, time, observed_state) float64 2kB -0.9498 ... 0.3309\n",
-       "        Attributes:\n",
-       "            created_at:                 2026-05-29T00:26:48.354952+00:00\n",
-       "            creation_library:           ArviZ\n",
-       "            creation_library_version:   1.1.0\n",
-       "            creation_library_language:  Python\n",
-       "            inference_library:          pymc\n",
-       "            inference_library_version:  6.0.0\n",
-       "            sample_dims:                []
" - ], - "text/plain": [ - "\n", - "Group: /\n", - "├── Group: /posterior_predictive\n", - "│ Dimensions: (chain: 4, draw: 200, batch: 2, time: 100, state: 3,\n", - "│ state_aux: 3)\n", - "│ Coordinates:\n", - "│ * chain (chain) int64 32B 0 1 2 3\n", - "│ * draw (draw) int64 2kB 0 1 2 3 4 5 ... 195 196 197 198 199\n", - "│ * batch (batch) int64 16B 0 1\n", - "│ * time (time) int64 800B 0 1 2 3 4 5 6 ... 94 95 96 97 98 99\n", - "│ * state (state) \n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.DataTree 'posterior_predictive'>\n",
-       "Group: /posterior_predictive\n",
-       "    Dimensions:            (chain: 4, draw: 200, time: 10, batch: 2, state: 3,\n",
-       "                            observed_state: 1)\n",
-       "    Coordinates:\n",
-       "      * chain              (chain) int64 32B 0 1 2 3\n",
-       "      * draw               (draw) int64 2kB 0 1 2 3 4 5 ... 194 195 196 197 198 199\n",
-       "      * time               (time) int64 80B 100 101 102 103 104 105 106 107 108 109\n",
-       "      * batch              (batch) <U7 56B 'batch_1' 'batch_2'\n",
-       "      * state              (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
-       "      * observed_state     (observed_state) <U3 12B 'ts1'\n",
-       "    Data variables:\n",
-       "        forecast_latent    (chain, draw, time, batch, state) float64 384kB -1.179...\n",
-       "        forecast_observed  (chain, draw, time, batch, observed_state) float64 128kB ...\n",
-       "    Attributes:\n",
-       "        created_at:                 2026-05-29T00:26:56.955319+00:00\n",
-       "        creation_library:           ArviZ\n",
-       "        creation_library_version:   1.1.0\n",
-       "        creation_library_language:  Python\n",
-       "        inference_library:          pymc\n",
-       "        inference_library_version:  6.0.0\n",
-       "        sample_dims:                ['chain', 'draw']
" - ], - "text/plain": [ - "\n", - "Group: /posterior_predictive\n", - " Dimensions: (chain: 4, draw: 200, time: 10, batch: 2, state: 3,\n", - " observed_state: 1)\n", - " Coordinates:\n", - " * chain (chain) int64 32B 0 1 2 3\n", - " * draw (draw) int64 2kB 0 1 2 3 4 5 ... 194 195 196 197 198 199\n", - " * time (time) int64 80B 100 101 102 103 104 105 106 107 108 109\n", - " * batch (batch) \n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "post = ar3.sample_conditional_posterior(idata, mvn_method=\"cholesky\")" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "5c3800a9", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Sampling: [ar_params, obs, sigma_x]\n" - ] - } - ], + "outputs": [], "source": [ "with pymc_mod:\n", " prior = pm.sample_prior_predictive(compile_kwargs={\"mode\": \"NUMBA\"})" @@ -3250,2367 +307,40 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "id": "5857a994", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/dekermanjian/Desktop/Open_Source_Contributions/pymc-extras/pymc_extras/statespace/utils/data_tools.py:78: UserWarning: No time index found on the supplied data. A simple range index will be automatically generated.\n", - " warnings.warn(NO_TIME_INDEX_WARNING)\n", - "Sampling: [filtered_prior, filtered_prior_observed, predicted_prior, predicted_prior_observed, smoothed_prior, smoothed_prior_observed]\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "69a5f40528374795b2144353ba1f7dfe", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    },
-    {
-     "data": {
-      "text/html": [
-       "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.DataTree 'posterior_predictive'>\n",
-       "Group: /posterior_predictive\n",
-       "    Dimensions:                   (chain: 1, draw: 500, batch: 2, time: 100,\n",
-       "                                   state: 3, observed_state: 1)\n",
-       "    Coordinates:\n",
-       "      * chain                     (chain) int64 8B 0\n",
-       "      * draw                      (draw) int64 4kB 0 1 2 3 4 ... 495 496 497 498 499\n",
-       "      * batch                     (batch) <U7 56B 'batch_1' 'batch_2'\n",
-       "      * time                      (time) int64 800B 0 1 2 3 4 5 ... 95 96 97 98 99\n",
-       "      * state                     (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
-       "      * observed_state            (observed_state) <U3 12B 'ts1'\n",
-       "    Data variables:\n",
-       "        filtered_prior            (chain, draw, batch, time, state) float64 2MB -...\n",
-       "        filtered_prior_observed   (chain, draw, batch, time, observed_state) float64 800kB ...\n",
-       "        predicted_prior           (chain, draw, batch, time, state) float64 2MB -...\n",
-       "        predicted_prior_observed  (chain, draw, batch, time, observed_state) float64 800kB ...\n",
-       "        smoothed_prior            (chain, draw, batch, time, state) float64 2MB -...\n",
-       "        smoothed_prior_observed   (chain, draw, batch, time, observed_state) float64 800kB ...\n",
-       "    Attributes:\n",
-       "        created_at:                 2026-05-29T00:27:23.480095+00:00\n",
-       "        creation_library:           ArviZ\n",
-       "        creation_library_version:   1.1.0\n",
-       "        creation_library_language:  Python\n",
-       "        inference_library:          pymc\n",
-       "        inference_library_version:  6.0.0\n",
-       "        sample_dims:                ['chain', 'draw']
" - ], - "text/plain": [ - "\n", - "Group: /posterior_predictive\n", - " Dimensions: (chain: 1, draw: 500, batch: 2, time: 100,\n", - " state: 3, observed_state: 1)\n", - " Coordinates:\n", - " * chain (chain) int64 8B 0\n", - " * draw (draw) int64 4kB 0 1 2 3 4 ... 495 496 497 498 499\n", - " * batch (batch) \n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.DataTree 'posterior_predictive'>\n",
-       "Group: /posterior_predictive\n",
-       "    Dimensions:         (chain: 1, draw: 500, time: 100, batch: 2, state: 3,\n",
-       "                         observed_state: 1)\n",
-       "    Coordinates:\n",
-       "      * chain           (chain) int64 8B 0\n",
-       "      * draw            (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n",
-       "      * time            (time) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99\n",
-       "      * batch           (batch) <U7 56B 'batch_1' 'batch_2'\n",
-       "      * state           (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
-       "      * observed_state  (observed_state) <U3 12B 'ts1'\n",
-       "    Data variables:\n",
-       "        prior_latent    (chain, draw, time, batch, state) float64 2MB -0.3617 ......\n",
-       "        prior_observed  (chain, draw, time, batch, observed_state) float64 800kB ...\n",
-       "    Attributes:\n",
-       "        created_at:                 2026-05-29T00:27:25.728592+00:00\n",
-       "        creation_library:           ArviZ\n",
-       "        creation_library_version:   1.1.0\n",
-       "        creation_library_language:  Python\n",
-       "        inference_library:          pymc\n",
-       "        inference_library_version:  6.0.0\n",
-       "        sample_dims:                ['chain', 'draw']
" - ], - "text/plain": [ - "\n", - "Group: /posterior_predictive\n", - " Dimensions: (chain: 1, draw: 500, time: 100, batch: 2, state: 3,\n", - " observed_state: 1)\n", - " Coordinates:\n", - " * chain (chain) int64 8B 0\n", - " * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499\n", - " * time (time) int64 800B 0 1 2 3 4 5 6 7 ... 93 94 95 96 97 98 99\n", - " * batch (batch) \n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "unpost = ar3.sample_unconditional_posterior(idata, mvn_method=\"cholesky\")" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "id": "b112c551", "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.DataTree 'posterior_predictive'>\n",
-       "Group: /posterior_predictive\n",
-       "    Dimensions:             (chain: 4, draw: 200, time: 100, batch: 2, state: 3,\n",
-       "                             observed_state: 1)\n",
-       "    Coordinates:\n",
-       "      * chain               (chain) int64 32B 0 1 2 3\n",
-       "      * draw                (draw) int64 2kB 0 1 2 3 4 5 ... 194 195 196 197 198 199\n",
-       "      * time                (time) int64 800B 0 1 2 3 4 5 6 ... 93 94 95 96 97 98 99\n",
-       "      * batch               (batch) <U7 56B 'batch_1' 'batch_2'\n",
-       "      * state               (state) <U7 84B 'L1.data' 'L2.data' 'L3.data'\n",
-       "      * observed_state      (observed_state) <U3 12B 'ts1'\n",
-       "    Data variables:\n",
-       "        posterior_latent    (chain, draw, time, batch, state) float64 4MB -1.517 ...\n",
-       "        posterior_observed  (chain, draw, time, batch, observed_state) float64 1MB ...\n",
-       "    Attributes:\n",
-       "        created_at:                 2026-05-29T00:27:27.765195+00:00\n",
-       "        creation_library:           ArviZ\n",
-       "        creation_library_version:   1.1.0\n",
-       "        creation_library_language:  Python\n",
-       "        inference_library:          pymc\n",
-       "        inference_library_version:  6.0.0\n",
-       "        sample_dims:                ['chain', 'draw']
" - ], - "text/plain": [ - "\n", - "Group: /posterior_predictive\n", - " Dimensions: (chain: 4, draw: 200, time: 100, batch: 2, state: 3,\n", - " observed_state: 1)\n", - " Coordinates:\n", - " * chain (chain) int64 32B 0 1 2 3\n", - " * draw (draw) int64 2kB 0 1 2 3 4 5 ... 194 195 196 197 198 199\n", - " * time (time) int64 800B 0 1 2 3 4 5 6 ... 93 94 95 96 97 98 99\n", - " * batch (batch) pm.Model: + mod = st.LevelTrend(order=2, innovations_order=[0, 1]) + mod += st.Autoregressive(name="ar", order=1) + mod += st.MeasurementError(name="obs") + ssm = mod.build(name="batch_ssm", mode="NUMBA") + + if len(batch_dims) > 1: + raise NotImplementedError + + batched_data = np.random.normal(0, 1, size=(*batch_dims, 100, 1)) + + with pm.Model( + coords=ssm.coords | {"batch": [f"batch_{i}" for i in range(batch_dims[0])]} + ) as pymc_mod: + P0 = pm.Deterministic( + "P0", pt.tile(pt.eye(3) * 1, (*batch_dims, 1, 1)), dims=("batch", "state", "state_aux") + ) + + initial_level_trend = pm.Normal("initial_level_trend", dims=("batch", "state_level_trend")) + params_ar = pm.Beta("params_ar", alpha=3, beta=3, dims=("batch", "lag_ar")) + + sigma_level_trend = pm.Gamma( + "sigma_level_trend", alpha=2, beta=50, dims=("batch", "shock_level_trend") + ) + sigma_ar = pm.Gamma("sigma_ar", alpha=2, beta=5, shape=(*batch_dims,)) + sigma_obs = pm.HalfNormal("sigma_obs", sigma=0.05, shape=(*batch_dims,)) + + ssm.build_statespace_graph(data=batched_data) + + coord_dimensions = {k: len(v) for k, v in pymc_mod.coords.items()} + + with pymc_mod: + prior = pm.sample_prior_predictive(compile_kwargs={"mode": "NUMBA"}, random_seed=rng) + idata = pm.sample(tune=10, draws=10, compile_kwargs={"mode": "NUMBA"}, random_seed=rng) + + unconditional_prior = ssm.sample_unconditional_prior( + prior, mvn_method="cholesky", random_seed=rng + ) + conditional_prior = ssm.sample_conditional_prior(prior, mvn_method="cholesky", random_seed=rng) + + prior_chain = prior.prior.dims["chain"] + prior_draw = prior.prior.dims["draw"] + + assert prior.prior["params_ar"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["lag_ar"], + ) + assert prior.prior["sigma_ar"].shape == (prior_chain, prior_draw, coord_dimensions["batch"]) + assert prior.prior["initial_level_trend"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["state_level_trend"], + ) + assert prior.prior["sigma_level_trend"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["shock_level_trend"], + ) + assert prior.prior["sigma_obs"].shape == (prior_chain, prior_draw, coord_dimensions["batch"]) + assert prior.prior["P0"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["state"], + coord_dimensions["state_aux"], + ) + + assert conditional_prior["filtered_prior"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["state"], + ) + assert conditional_prior["filtered_prior_observed"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["observed_state"], + ) + assert conditional_prior["predicted_prior"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["state"], + ) + assert conditional_prior["predicted_prior_observed"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["observed_state"], + ) + assert conditional_prior["smoothed_prior"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["state"], + ) + assert conditional_prior["smoothed_prior_observed"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["observed_state"], + ) + + assert unconditional_prior["prior_latent"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["time"], + coord_dimensions["batch"], + coord_dimensions["state"], + ) + assert unconditional_prior["prior_observed"].shape == ( + prior_chain, + prior_draw, + coord_dimensions["time"], + coord_dimensions["batch"], + coord_dimensions["observed_state"], + ) + + posterior_chain = idata.posterior.dims["chain"] + posterior_draw = idata.posterior.dims["draw"] + + assert idata.posterior["params_ar"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["lag_ar"], + ) + assert idata.posterior["sigma_ar"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + ) + assert idata.posterior["initial_level_trend"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["state_level_trend"], + ) + assert idata.posterior["sigma_level_trend"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["shock_level_trend"], + ) + assert idata.posterior["sigma_obs"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + ) + assert idata.posterior["P0"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["state"], + coord_dimensions["state_aux"], + ) + + irf = ssm.impulse_response_function(idata, random_seed=rng) + irf_steps = irf.dims["time"] + assert irf.irf.shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + irf_steps, + coord_dimensions["state"], + ) + + T_sample = ssm.sample_statespace_matrices(idata, matrix_names=["T"]) + assert T_sample.posterior_predictive.T.shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["state"], + coord_dimensions["state_aux"], + ) + + filtered_covariance_sample = ssm.sample_filter_outputs( + idata, filter_output_names=["filtered_covariances"] + ) + assert filtered_covariance_sample.posterior_predictive.filtered_covariances.shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["state"], + coord_dimensions["state_aux"], + ) + + forecast = ssm.forecast(idata, periods=10, random_seed=rng) + forecast_steps = forecast.dims["time"] + assert forecast.forecast_latent.shape == ( + posterior_chain, + posterior_draw, + forecast_steps, + coord_dimensions["batch"], + coord_dimensions["state"], + ) + assert forecast.forecast_observed.shape == ( + posterior_chain, + posterior_draw, + forecast_steps, + coord_dimensions["batch"], + coord_dimensions["observed_state"], + ) + + unconditional_post = ssm.sample_unconditional_posterior( + idata, mvn_method="cholesky", random_seed=rng + ) + assert unconditional_post.posterior_latent.shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["time"], + coord_dimensions["batch"], + coord_dimensions["state"], + ) + assert unconditional_post.posterior_observed.shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["time"], + coord_dimensions["batch"], + coord_dimensions["observed_state"], + ) + + conditional_post = ssm.sample_conditional_posterior( + idata, mvn_method="cholesky", random_seed=rng + ) + assert conditional_post["filtered_posterior"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["state"], + ) + assert conditional_post["filtered_posterior_observed"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["observed_state"], + ) + assert conditional_post["predicted_posterior"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["state"], + ) + assert conditional_post["predicted_posterior_observed"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["observed_state"], + ) + assert conditional_post["smoothed_posterior"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["state"], + ) + assert conditional_post["smoothed_posterior_observed"].shape == ( + posterior_chain, + posterior_draw, + coord_dimensions["batch"], + coord_dimensions["time"], + coord_dimensions["observed_state"], + ) From 86916f0a8e80737a18661a52da97eed05e888217 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian Date: Thu, 11 Jun 2026 05:58:05 -0600 Subject: [PATCH 10/10] removed mock_sample from batch_ssm tests --- tests/statespace/core/test_statespace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/statespace/core/test_statespace.py b/tests/statespace/core/test_statespace.py index 14601108e..6e1210d9e 100644 --- a/tests/statespace/core/test_statespace.py +++ b/tests/statespace/core/test_statespace.py @@ -1401,7 +1401,7 @@ def test_impulse_response_function(self, ss_mod_time_varying, idata_time_varying @pytest.mark.filterwarnings("ignore:No time index found on the supplied data.") @pytest.mark.filterwarnings("ignore:No start date provided") @pytest.mark.parametrize("batch_dims", [(2,), (5,), (7,)]) -def test_batch_ssm(batch_dims: tuple[int], rng, mock_pymc_sample) -> pm.Model: +def test_batch_ssm(batch_dims: tuple[int], rng) -> pm.Model: mod = st.LevelTrend(order=2, innovations_order=[0, 1]) mod += st.Autoregressive(name="ar", order=1) mod += st.MeasurementError(name="obs")