Skip to content

Commit 81c6f0e

Browse files
Fix bug in Statespace model state names (#634)
* Fix ETS states
* Fix DFM states
* Add workflow tests for SARIMAX and VARMAX
1 parent d8de1c6 commit 81c6f0e

6 files changed

Lines changed: 189 additions & 17 deletions

File tree

pymc_extras/statespace/models/DFM.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,12 @@ def set_states(self) -> State | tuple[State, ...] | None:
486486
for endog_name in self.endog_names
487487
)
488488

489-
return tuple(State(name=name, observed=False, shared=False) for name in names)
489+
hidden_states = [State(name=name, observed=False, shared=False) for name in names]
490+
observed_states = [
491+
State(name=name, observed=True, shared=False) for name in self.endog_names
492+
]
493+
494+
return *hidden_states, *observed_states
490495

491496
def set_shocks(self) -> Shock | tuple[Shock, ...] | None:
492497
shock_names = [f"factor_shock_{i}" for i in range(self.k_factors)]
@@ -549,13 +554,6 @@ def set_coords(self) -> Coord | tuple[Coord, ...] | None:
549554

550555
return tuple(coords)
551556

552-
@property
553-
def observed_states(self) -> tuple[str, ...]:
554-
"""
555-
Returns the names of the observed states (i.e., the endogenous variables).
556-
"""
557-
return self.endog_names
558-
559557
def make_symbolic_graph(self):
560558
if not self.exog_flag:
561559
x0 = self.make_and_register_variable("x0", shape=(self.k_states,), dtype=floatX)

pymc_extras/statespace/models/ETS.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -446,15 +446,13 @@ def set_states(self) -> State | tuple[State, ...] | None:
446446
else:
447447
state_names = base_states
448448

449-
# First state for each endog is the innovation (observed), rest are hidden
450-
states = []
451-
states_per_endog = len(base_states)
452-
for i, name in enumerate(state_names):
453-
# innovation states are "observed" in the sense they directly affect the observation
454-
is_observed = (i % states_per_endog) == 0
455-
states.append(State(name=name, observed=is_observed, shared=False))
456-
457-
return tuple(states)
449+
hidden_states = [State(name=name, observed=False, shared=False) for name in state_names]
450+
451+
observed_states = [
452+
State(name=name, observed=True, shared=False) for name in self.endog_names
453+
]
454+
455+
return *hidden_states, *observed_states
458456

459457
def set_shocks(self) -> Shock | tuple[Shock, ...] | None:
460458
k_endog = self.k_endog

tests/statespace/models/test_DFM.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import statsmodels.api as sm
1010

1111
from numpy.testing import assert_allclose
12+
from pymc.testing import mock_sample_setup_and_teardown
1213
from pytensor.graph.traversal import explicit_graph_inputs
1314
from statsmodels.tsa.statespace.dynamic_factor import DynamicFactor
1415

@@ -28,6 +29,8 @@
2829
)
2930
from tests.statespace.shared_fixtures import rng
3031

32+
mock_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown)
33+
3134
floatX = pytensor.config.floatX
3235

3336

@@ -718,3 +721,50 @@ def test_exog_not_shared_no_exog_innovations(self):
718721
assert len(mod.shock_names) == k_factors + k_endog + (
719722
k_exog if shared_exog_states else k_exog * k_endog
720723
)
724+
725+
726+
def test_dfm_workflow(rng, mock_sample):
727+
df = pd.read_csv(
728+
"tests/statespace/_data/statsmodels_macrodata_processed.csv",
729+
index_col=0,
730+
parse_dates=True,
731+
).astype(floatX)
732+
df.index.freq = df.index.inferred_freq
733+
734+
ss_mod = BayesianDynamicFactor(
735+
endog_names=df.columns.tolist(),
736+
k_factors=1,
737+
factor_order=1,
738+
error_order=0,
739+
measurement_error=True,
740+
verbose=False,
741+
)
742+
743+
with pm.Model(coords=ss_mod.coords) as m:
744+
pm.Normal("x0", dims=["state"])
745+
P0_diag = pm.Exponential("P0_diag", 1, dims=["state"])
746+
pm.Deterministic("P0", pt.diag(P0_diag), dims=["state", "state_aux"])
747+
748+
pm.Normal("factor_loadings", dims=["observed_state", "factor"])
749+
pm.Normal("factor_ar", dims=["factor", "lag_ar"])
750+
pm.Exponential("error_sigma", 1, dims=["observed_state"])
751+
pm.Exponential("sigma_obs", 1, dims=["observed_state"])
752+
753+
ss_mod.build_statespace_graph(df)
754+
755+
idata = pm.sample()
756+
757+
post = ss_mod.sample_conditional_posterior(idata, mvn_method="svd")
758+
assert "filtered_posterior" in post
759+
assert "smoothed_posterior" in post
760+
assert "predicted_posterior" in post
761+
762+
forecast = ss_mod.forecast(idata, periods=10, random_seed=rng)
763+
assert "forecast_latent" in forecast
764+
assert "forecast_observed" in forecast
765+
assert np.isfinite(forecast.forecast_latent.values).all()
766+
assert np.isfinite(forecast.forecast_observed.values).all()
767+
768+
irf = ss_mod.impulse_response_function(idata, n_steps=10, random_seed=rng)
769+
assert "irf" in irf
770+
assert np.isfinite(irf.irf.values).all()

tests/statespace/models/test_ETS.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import numpy as np
2+
import pymc as pm
23
import pytensor
34
import pytest
45
import statsmodels.api as sm
56

67
from numpy.testing import assert_allclose
8+
from pymc.testing import mock_sample_setup_and_teardown
79
from pytensor.graph.traversal import explicit_graph_inputs
810
from scipy import linalg
911

@@ -12,6 +14,8 @@
1214
from tests.statespace.shared_fixtures import rng
1315
from tests.statespace.test_utilities import load_nile_test_data
1416

17+
mock_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown)
18+
1519

1620
@pytest.fixture(scope="session")
1721
def data():
@@ -419,3 +423,44 @@ def test_ETS_stationary_initialization():
419423
P0_expected = linalg.solve_discrete_lyapunov(T_stationary, R @ Q @ R.T)
420424

421425
assert_allclose(outputs["initial_state_cov"], P0_expected, rtol=1e-8, atol=1e-8)
426+
427+
428+
def test_ets_workflow(mock_sample):
429+
data = load_nile_test_data()
430+
431+
ss_mod = BayesianETS(
432+
order=("A", "Ad", "N"),
433+
endog_names=["height"],
434+
stationary_initialization=True,
435+
measurement_error=True,
436+
initialization_dampening=0.8,
437+
)
438+
439+
with pm.Model(coords=ss_mod.coords) as m:
440+
pm.Normal("initial_level", 0, 1)
441+
pm.Normal("initial_trend", 0, 1)
442+
pm.Beta("alpha", 1, 1)
443+
pm.Beta("beta", 1, 1)
444+
pm.Beta("phi", 1, 1)
445+
446+
pm.Exponential("sigma_state", 1)
447+
pm.Exponential("sigma_obs", 1)
448+
449+
ss_mod.build_statespace_graph(data)
450+
451+
idata = pm.sample()
452+
453+
post = ss_mod.sample_conditional_posterior(idata, mvn_method="cholesky")
454+
assert "filtered_posterior" in post
455+
assert "smoothed_posterior" in post
456+
assert "predicted_posterior" in post
457+
458+
forecast = ss_mod.forecast(idata, periods=10, random_seed=42)
459+
assert "forecast_latent" in forecast
460+
assert "forecast_observed" in forecast
461+
assert np.isfinite(forecast.forecast_latent.values).all()
462+
assert np.isfinite(forecast.forecast_observed.values).all()
463+
464+
irf = ss_mod.impulse_response_function(idata, n_steps=10, random_seed=42)
465+
assert "irf" in irf
466+
assert np.isfinite(irf.irf.values).all()

tests/statespace/models/test_SARIMAX.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,3 +469,39 @@ def test_SARIMA_with_exogenous(rng, mock_sample):
469469
2,
470470
)
471471
np.testing.assert_allclose(ss_mod._fit_exog_data["exogenous_data"]["value"], data_val)
472+
473+
474+
def test_sarimax_workflow(mock_sample):
475+
data = load_nile_test_data()
476+
477+
ss_mod = BayesianSARIMAX(
478+
order=(1, 0, 1),
479+
stationary_initialization=True,
480+
measurement_error=True,
481+
verbose=False,
482+
)
483+
484+
with pm.Model(coords=ss_mod.coords) as m:
485+
pm.Normal("ar_params", dims=["lag_ar"])
486+
pm.Normal("ma_params", dims=["lag_ma"])
487+
pm.Exponential("sigma_state", 1)
488+
pm.Exponential("sigma_obs", 1)
489+
490+
ss_mod.build_statespace_graph(data)
491+
492+
idata = pm.sample()
493+
494+
post = ss_mod.sample_conditional_posterior(idata, mvn_method="svd")
495+
assert "filtered_posterior" in post
496+
assert "smoothed_posterior" in post
497+
assert "predicted_posterior" in post
498+
499+
forecast = ss_mod.forecast(idata, periods=10, random_seed=42)
500+
assert "forecast_latent" in forecast
501+
assert "forecast_observed" in forecast
502+
assert np.isfinite(forecast.forecast_latent.values).all()
503+
assert np.isfinite(forecast.forecast_observed.values).all()
504+
505+
irf = ss_mod.impulse_response_function(idata, n_steps=10, random_seed=42)
506+
assert "irf" in irf
507+
assert np.isfinite(irf.irf.values).all()

tests/statespace/models/test_VARMAX.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,16 @@
1010

1111
from numpy.testing import assert_allclose, assert_array_less
1212
from pymc.model.transform.optimization import freeze_dims_and_data
13+
from pymc.testing import mock_sample_setup_and_teardown
1314

1415
from pymc_extras.statespace import BayesianVARMAX
1516
from pymc_extras.statespace.utils.constants import SHORT_NAME_TO_LONG
1617
from tests.statespace.shared_fixtures import ( # pylint: disable=unused-import
1718
rng,
1819
)
1920

21+
mock_sample = pytest.fixture(scope="function")(mock_sample_setup_and_teardown)
22+
2023
floatX = pytensor.config.floatX
2124
ps = [0, 1, 2, 3]
2225
qs = [0, 1, 2, 3]
@@ -200,6 +203,48 @@ def test_forecast(varma_mod, idata, rng):
200203
assert np.isfinite(forecast.forecast_observed.values).all()
201204

202205

206+
def test_varmax_workflow(rng, mock_sample):
207+
df = pd.read_csv(
208+
"tests/statespace/_data/statsmodels_macrodata_processed.csv",
209+
index_col=0,
210+
parse_dates=True,
211+
).astype(floatX)
212+
df.index.freq = df.index.inferred_freq
213+
214+
ss_mod = BayesianVARMAX(
215+
endog_names=df.columns,
216+
order=(1, 0),
217+
stationary_initialization=True,
218+
measurement_error=True,
219+
verbose=False,
220+
)
221+
222+
with pm.Model(coords=ss_mod.coords) as m:
223+
state_cov_diag = pm.Exponential("state_cov_diag", 1, dims=["shock"])
224+
pm.Deterministic("state_cov", pt.diag(state_cov_diag), dims=["shock", "shock_aux"])
225+
pm.Normal("ar_params", sigma=0.1, dims=["observed_state", "lag_ar", "observed_state_aux"])
226+
pm.Exponential("sigma_obs", 1, dims=["observed_state"])
227+
228+
ss_mod.build_statespace_graph(df)
229+
230+
idata = pm.sample()
231+
232+
post = ss_mod.sample_conditional_posterior(idata, mvn_method="svd")
233+
assert "filtered_posterior" in post
234+
assert "smoothed_posterior" in post
235+
assert "predicted_posterior" in post
236+
237+
forecast = ss_mod.forecast(idata, periods=10, random_seed=rng)
238+
assert "forecast_latent" in forecast
239+
assert "forecast_observed" in forecast
240+
assert np.isfinite(forecast.forecast_latent.values).all()
241+
assert np.isfinite(forecast.forecast_observed.values).all()
242+
243+
irf = ss_mod.impulse_response_function(idata, n_steps=10, random_seed=rng)
244+
assert "irf" in irf
245+
assert np.isfinite(irf.irf.values).all()
246+
247+
203248
class TestVARMAXWithExogenous:
204249
def test_create_varmax_with_exogenous_list_of_names(self, data):
205250
mod = BayesianVARMAX(

0 commit comments

Comments
 (0)