Skip to content

Commit ec6a4bd

Browse files
committed
Align structural naming registry with prefixed state space names
1 parent 27a28ea commit ec6a4bd

2 files changed

Lines changed: 60 additions & 3 deletions

File tree

pymc_extras/statespace/models/structural/core.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ class StructuralTimeSeries(PyMCStateSpace):
117117
def __init__(
118118
self,
119119
ssm: PytensorRepresentation,
120-
name: str,
120+
name: str | None,
121121
coords_info: CoordInfo,
122122
param_info: ParameterInfo,
123123
data_info: DataInfo,
@@ -184,10 +184,31 @@ def __init__(
184184
verbose=verbose,
185185
measurement_error=measurement_error,
186186
mode=mode,
187+
name=name,
187188
)
188189

189-
self._tensor_variable_info = tensor_variable_info
190-
self._tensor_data_info = tensor_data_info
190+
if name is None:
191+
self._tensor_variable_info = tensor_variable_info
192+
self._tensor_data_info = tensor_data_info
193+
else:
194+
self._tensor_variable_info = SymbolicVariableInfo(
195+
symbolic_variables=tuple(
196+
SymbolicVariable(
197+
name=self.prefixed_name(symbolic_variable.name),
198+
symbolic_variable=symbolic_variable.symbolic_variable,
199+
)
200+
for symbolic_variable in tensor_variable_info
201+
)
202+
)
203+
self._tensor_data_info = SymbolicDataInfo(
204+
symbolic_data=tuple(
205+
SymbolicData(
206+
name=self.prefixed_name(symbolic_data.name),
207+
symbolic_data=symbolic_data.symbolic_data,
208+
)
209+
for symbolic_data in tensor_data_info
210+
)
211+
)
191212
self._component_info = component_info.copy()
192213
self._exog_names = data_info.exogenous_names
193214
self._needs_exog_data = data_info.needs_exogenous_data

tests/statespace/models/structural/test_core.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,42 @@
1818
RTOL = 0 if floatX.endswith("64") else 1e-6
1919

2020

21+
def _build_named_structural_model(name: str):
22+
return (
23+
st.LevelTrend(order=1, innovations_order=1)
24+
+ st.Regression(name="reg", state_names=["x"])
25+
+ st.MeasurementError(name="obs")
26+
).build(name=name, verbose=False)
27+
28+
29+
def test_structural_name_propagates_to_base_and_scopes_p0():
30+
ss_mod = _build_named_structural_model(name="m1")
31+
32+
assert ss_mod.name == "m1"
33+
assert "P0" in ss_mod.param_names
34+
assert ss_mod.prefixed_name("P0") in ss_mod._name_to_variable
35+
assert "P0" not in ss_mod._name_to_variable
36+
37+
38+
def test_named_structural_models_do_not_collide_in_placeholder_registries():
39+
with pm.Model():
40+
m1 = _build_named_structural_model(name="m1")
41+
m2 = _build_named_structural_model(name="m2")
42+
43+
var_keys_1 = set(m1._name_to_variable)
44+
var_keys_2 = set(m2._name_to_variable)
45+
data_keys_1 = set(m1._name_to_data)
46+
data_keys_2 = set(m2._name_to_data)
47+
48+
assert var_keys_1.isdisjoint(var_keys_2)
49+
assert data_keys_1.isdisjoint(data_keys_2)
50+
51+
assert var_keys_1 == {m1.prefixed_name(name) for name in m1.param_names}
52+
assert var_keys_2 == {m2.prefixed_name(name) for name in m2.param_names}
53+
assert data_keys_1 == {m1.prefixed_name(name) for name in m1.data_names}
54+
assert data_keys_2 == {m2.prefixed_name(name) for name in m2.data_names}
55+
56+
2157
def test_add_components():
2258
ll = st.LevelTrend(order=2)
2359
se = st.TimeSeasonality(name="seasonal", season_length=12)

0 commit comments

Comments
 (0)