Skip to content

Commit a793417

Browse files
Refactor differenced processes, random walks, and AR processes (#380)
* Checkpoint commit on differencedar * Checkpoint commit on differencedar * Add de-differencing helper and tests * Replace fori_loop with scan * Fix bugs and typos in integrator, add test * Fix name of StandardNormalRandomWalk * Try two random walk impelmentations * Fix bug in differenced process * Fix some tests, add dynamic versus static DistributionalRVs * Split distributional RVs into static and dynamic * Update DistributionalRV kwarg dist => distribution in all tests * update dist => distribution kwarg in DistributionalRV in all tutorials * Add tests for DistributionalRV factory and classes * Allow scalars to play better with scan, give more informative error messages for incompatible arrays in certain scan functions * Autoformat files * Rename infection_initialization tests to be consistent with name of tested classes, refactor all tests to work with new process module, ensure that DifferencedProcess samples the fundamental process n-1 times to account for inits, use pytest.mark.parametrize in rw sample distribution test * Fix typo * Refactor AR and update tests * Refactor RtPeriodicDiff to use DifferencedProcess and ARProcess; delete unused manual AR function * Default names for now * Fundamental process init passthrough for differencedprocess, update tutorials * Fix typo in IIDRandomSequence class name, add test for class * Apply suggestions from code review Co-authored-by: Damon Bayer <xum8@cdc.gov> * Reintroduce test_rtperiodicdiff, removing manual reconstruction test that did not use the sampling method * Coerce to 1d in the appropriate place in infectionswithfeedback * Reintroduce padding * Coerce to array in tutorial * Fix tutorial bug * Update model/src/pyrenew/process/differencedprocess.py Co-authored-by: Damon Bayer <xum8@cdc.gov> * Restore padding in test_model_basic_renewal * Restore padding in test in a couple other places * Apply suggestions from code review revert to using gen_int.size() method Co-authored-by: Damon Bayer <xum8@cdc.gov> * Convert sds to scalars in test ar process * Rename PeriodicDiff classes to highlight that they are DiffAR processes * Move model up one level * Add expand method for distributional rvs * Mathtext and refs for differencedprocess class * Mathtext typo fixes and doc improvements for DifferencedProcess * More sphinx tweaks * More typos * Test for standard normal sequence * Float expectations * Add scipy to testing deps * Update pyproject for deptry * scipy allowed to be dev dep * Add test for vectorized sampling * Raise ValueError if noise sd not scalar * Fix expand_by() tests to use new syntax * Manual tests for integrator correctness, better documentation of output shape, revise sample method so that n is chosen properly * Fix distributional rv test * Raise eror for overly short DifferenceProcess samples; test for that error raise * Force differencedprocess to deal with 1D diffs and error otherwise, make it behave as expected for 0 < n < order * Fix scipy dep * Update test * Improve error message * Autoformat files * Update tutorials * More checks and clearer code for differencedprocess.py * Update src/pyrenew/process/differencedprocess.py Update link to Hyndman textbook Co-authored-by: Damon Bayer <xum8@cdc.gov> * Update src/pyrenew/process/differencedprocess.py Fix out of date inaccurate docstring Co-authored-by: Damon Bayer <xum8@cdc.gov> * Update src/pyrenew/process/differencedprocess.py Co-authored-by: Damon Bayer <xum8@cdc.gov> * Remove names for DifferencedProcess * Change refs to expected shape * One missing edit in test * Restore incorrectly removed test * Update random walk tests * Make AR strict about 1d input arrays * Configurable AR process name in rtperiodicdiffar class * Added check comparing to first initial value to small sample test * Remove unused function * Remove required names and adjust tests and structure accordingly * replace :fun: sphinx directives * Fix one more :fun: * And yet one more :fun: --------- Co-authored-by: Damon Bayer <xum8@cdc.gov>
1 parent 7ad81f1 commit a793417

38 files changed

Lines changed: 1500 additions & 708 deletions

docs/source/tutorials/basic_renewal_model.qmd

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import jax.numpy as jnp
1616
import numpy as np
1717
import numpyro
1818
import numpyro.distributions as dist
19-
from pyrenew.process import SimpleRandomWalkProcess
19+
from pyrenew.process import RandomWalk
2020
from pyrenew.latent import (
2121
Infections,
2222
InfectionInitializationProcess,
@@ -68,7 +68,7 @@ flowchart LR
6868
end
6969
7070
subgraph process[Process module]
71-
rt["Rt_process_rv\n(Custom class built using SimpleRandomWalk)"]
71+
rt["Rt_process_rv\n(Custom class built using RandomWalk)"]
7272
end
7373
7474
subgraph deterministic[Deterministic module]
@@ -139,26 +139,25 @@ class MyRt(RandomVariable):
139139
def validate(self):
140140
pass
141141
142-
def sample(self, n_steps: int, **kwargs) -> tuple:
142+
def sample(self, n: int, **kwargs) -> tuple:
143143
sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))
144144
145145
rt_rv = TransformedRandomVariable(
146-
"Rt_rv",
147-
base_rv=SimpleRandomWalkProcess(
146+
name="log_rt_random_walk",
147+
base_rv=RandomWalk(
148148
name="log_rt",
149149
step_rv=DistributionalRV(
150-
name="rw_step_rv",
151-
distribution=dist.Normal(0, sd_rt),
152-
reparam=LocScaleReparam(0),
153-
),
154-
init_rv=DistributionalRV(
155-
name="init_log_rt",
156-
distribution=dist.Normal(jnp.log(1), jnp.log(1.2)),
150+
name="rw_step_rv", distribution=dist.Normal(0, 0.025)
157151
),
158152
),
159153
transforms=t.ExpTransform(),
160154
)
161-
return rt_rv.sample(n_steps=n_steps, **kwargs)
155+
rt_init_rv = DistributionalRV(
156+
name="init_log_rt", distribution=dist.Normal(0, 0.2)
157+
)
158+
init_rt, *_ = rt_init_rv.sample()
159+
160+
return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs)
162161
163162
164163
rt_proc = MyRt()

docs/source/tutorials/day_of_the_week.qmd

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class MyRt(metaclass.RandomVariable):
106106
def validate(self):
107107
pass
108108
109-
def sample(self, n_steps: int, **kwargs) -> tuple:
109+
def sample(self, n: int, **kwargs) -> tuple:
110110
# Standard deviation of the random walk
111111
sd_rt, *_ = self.sd_rv()
112112
@@ -115,15 +115,14 @@ class MyRt(metaclass.RandomVariable):
115115
name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value)
116116
)
117117
118-
init_rv = metaclass.DistributionalRV(
118+
rt_init_rv = metaclass.DistributionalRV(
119119
name="init_log_rt", distribution=dist.Normal(0, 0.2)
120120
)
121121
122122
# Random walk process
123-
base_rv = process.SimpleRandomWalkProcess(
123+
base_rv = process.RandomWalk(
124124
name="log_rt",
125125
step_rv=step_rv,
126-
init_rv=init_rv,
127126
)
128127
129128
# Transforming the random walk to the Rt scale
@@ -132,8 +131,9 @@ class MyRt(metaclass.RandomVariable):
132131
base_rv=base_rv,
133132
transforms=transformation.ExpTransform(),
134133
)
134+
init_rt, *_ = rt_init_rv.sample()
135135
136-
return rt_rv(n_steps=n_steps, **kwargs)
136+
return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs)
137137
138138
139139
rtproc = MyRt(

docs/source/tutorials/day_of_the_week.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
.. Please do not edit this file directly.
33
.. This file is just a placeholder.
44
.. For the source file, see:
5-
.. <https://github.com/CDCgov/multisignal-epi-inference/tree/main/docs/source/tutorials/day_of_the_week.qmd>
5+
.. <https://github.com/CDCgov/PyRenew/tree/main/docs/source/tutorials/day_of_the_week.qmd>

docs/source/tutorials/extending_pyrenew.qmd

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,12 @@ import numpyro.distributions as dist
2828
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
2929
from pyrenew.latent import InfectionsWithFeedback
3030
from pyrenew.model import RtInfectionsRenewalModel
31-
from pyrenew.process import SimpleRandomWalkProcess
32-
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
31+
from pyrenew.process import RandomWalk
32+
from pyrenew.metaclass import (
33+
RandomVariable,
34+
DistributionalRV,
35+
TransformedRandomVariable,
36+
)
3337
from pyrenew.latent import (
3438
InfectionInitializationProcess,
3539
InitializeInfectionsExponentialGrowth,
@@ -62,19 +66,31 @@ latent_infections = InfectionsWithFeedback(
6266
infection_feedback_pmf=gen_int,
6367
)
6468
65-
rt = TransformedRandomVariable(
66-
"Rt_rv",
67-
base_rv=SimpleRandomWalkProcess(
68-
name="log_rt",
69-
step_rv=DistributionalRV(
70-
name="rw_step_rv", distribution=dist.Normal(0, 0.025)
71-
),
72-
init_rv=DistributionalRV(
69+
70+
class MyRt(RandomVariable):
71+
72+
def validate(self):
73+
pass
74+
75+
def sample(self, n: int, **kwargs) -> tuple:
76+
sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))
77+
78+
rt_rv = TransformedRandomVariable(
79+
name="log_rt_random_walk",
80+
base_rv=RandomWalk(
81+
name="log_rt",
82+
step_rv=DistributionalRV(
83+
name="rw_step_rv", distribution=dist.Normal(0, 0.025)
84+
),
85+
),
86+
transforms=t.ExpTransform(),
87+
)
88+
rt_init_rv = DistributionalRV(
7389
name="init_log_rt", distribution=dist.Normal(0, 0.2)
74-
),
75-
),
76-
transforms=t.ExpTransform(),
77-
)
90+
)
91+
init_rt, *_ = rt_init_rv.sample()
92+
93+
return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs)
7894
```
7995

8096
With all the components defined, we can build the model:
@@ -85,7 +101,7 @@ model0 = RtInfectionsRenewalModel(
85101
gen_int_rv=gen_int,
86102
I0_rv=I0,
87103
latent_infections_rv=latent_infections,
88-
Rt_process_rv=rt,
104+
Rt_process_rv=MyRt(),
89105
infection_obs_process_rv=None,
90106
)
91107
```
@@ -209,10 +225,13 @@ class InfFeedback(RandomVariable):
209225
inf_feedback_strength, *_ = self.infection_feedback_strength(
210226
**kwargs,
211227
)
228+
229+
inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength.value)
230+
212231
inf_feedback_strength = au.pad_x_to_match_y(
213-
x=inf_feedback_strength.value,
232+
x=inf_feedback_strength,
214233
y=Rt,
215-
fill_value=inf_feedback_strength.value[0],
234+
fill_value=inf_feedback_strength[0],
216235
)
217236
218237
# Sampling inf feedback and adjusting the shape
@@ -260,7 +279,7 @@ model1 = RtInfectionsRenewalModel(
260279
gen_int_rv=gen_int,
261280
I0_rv=I0,
262281
latent_infections_rv=latent_infections2,
263-
Rt_process_rv=rt,
282+
Rt_process_rv=MyRt(),
264283
infection_obs_process_rv=None,
265284
)
266285

docs/source/tutorials/hospital_admissions_model.qmd

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -186,37 +186,32 @@ gen_int = deterministic.DeterministicPMF(name="gen_int", value=gen_int)
186186
187187
188188
class MyRt(metaclass.RandomVariable):
189-
def __init__(self, sd_rv):
190-
self.sd_rv = sd_rv
191189
192190
def validate(self):
193191
pass
194192
195-
def sample(self, n_steps: int, **kwargs) -> tuple:
196-
sd_rt, *_ = self.sd_rv()
193+
def sample(self, n: int, **kwargs) -> tuple:
194+
sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))
197195
198196
rt_rv = metaclass.TransformedRandomVariable(
199-
"Rt_rv",
200-
base_rv=process.SimpleRandomWalkProcess(
197+
name="log_rt_random_walk",
198+
base_rv=process.RandomWalk(
201199
name="log_rt",
202200
step_rv=metaclass.DistributionalRV(
203-
name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value)
204-
),
205-
init_rv=metaclass.DistributionalRV(
206-
name="init_log_rt", distribution=dist.Normal(0, 0.2)
201+
name="rw_step_rv", distribution=dist.Normal(0, 0.025)
207202
),
208203
),
209204
transforms=transformation.ExpTransform(),
210205
)
206+
rt_init_rv = metaclass.DistributionalRV(
207+
name="init_log_rt", distribution=dist.Normal(0, 0.2)
208+
)
209+
init_rt, *_ = rt_init_rv.sample()
211210
212-
return rt_rv.sample(n_steps=n_steps, **kwargs)
211+
return rt_rv.sample(n=n, init_vals=init_rt.value, **kwargs)
213212
214213
215-
rtproc = MyRt(
216-
metaclass.DistributionalRV(
217-
name="Rt_random_walk_sd", distribution=dist.HalfNormal(0.025)
218-
)
219-
)
214+
rtproc = MyRt()
220215
221216
# The observation model
222217

docs/source/tutorials/periodic_effects.qmd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ from pyrenew import process, deterministic
2424

2525
```{python}
2626
# The random process for Rt
27-
rt_proc = process.RtWeeklyDiffProcess(
27+
rt_proc = process.RtWeeklyDiffARProcess(
2828
name="rt_weekly_diff",
2929
offset=0,
3030
log_rt_rv=deterministic.DeterministicVariable(
@@ -57,7 +57,7 @@ for i in range(0, 30, 7):
5757
plt.show()
5858
```
5959

60-
The implementation of the `RtWeeklyDiffProcess` (which is an instance of `RtPeriodicDiffProcess`), uses `repeat_until_n` to repeating values: `repeat_until_n(..., period_size=7)`. The `RtWeeklyDiff` class is a particular case of `RtPeriodicDiff` with a period size of seven.
60+
The implementation of the `RtWeeklyDiffARProcess` (which is an instance of `RtPeriodicDiffARProcess`), uses `repeat_until_n` to repeating values: `repeat_until_n(..., period_size=7)`. The `RtWeeklyDiff` class is a particular case of `RtPeriodicDiff` with a period size of seven.
6161

6262
## Repeated sequences (tiling)
6363

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ sphinxcontrib-mermaid = "^0.9.2"
4141
sphinx-autodoc-typehints = "^2.1.0"
4242
sphinx-book-theme = "^1.1.2"
4343
ipywidgets = "^8.1.3"
44+
jupyter = "^1.0.0"
4445

4546
[tool.poetry.group.test]
4647
optional = true
@@ -49,6 +50,7 @@ optional = true
4950
pytest = "^8.3.2"
5051
pytest-cov = "^5.0.0"
5152
pytest-mpl = "^0.17.0"
53+
scipy = "^1.14.1"
5254

5355
[tool.numpydoc_validation]
5456
checks = [
@@ -81,4 +83,4 @@ build-backend = "poetry.core.masonry.api"
8183
known_first_party = ["pyrenew", "test"]
8284

8385
[tool.deptry.per_rule_ignores]
84-
DEP004 = ["pytest"]
86+
DEP004 = ["pytest", "scipy"]

src/pyrenew/deterministic/deterministic.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from __future__ import annotations
55

6-
import jax.numpy as jnp
76
import numpyro
87
from jax.typing import ArrayLike
98

@@ -42,7 +41,7 @@ def __init__(
4241
"""
4342
self.name = name
4443
self.validate(value)
45-
self.value = jnp.atleast_1d(value)
44+
self.value = value
4645
self.set_timeseries(t_start, t_unit)
4746

4847
return None

src/pyrenew/latent/infection_initialization_method.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def initialize_infections(self, I_pre_init: ArrayLike):
8787
ArrayLike
8888
An array of length ``n_timepoints`` with the number of initialized infections at each time point.
8989
"""
90+
I_pre_init = jnp.atleast_1d(I_pre_init)
9091
if self.n_timepoints < I_pre_init.size:
9192
raise ValueError(
9293
"I_pre_init must be no longer than n_timepoints. "
@@ -105,20 +106,23 @@ def initialize_infections(self, I_pre_init: ArrayLike):
105106
Parameters
106107
----------
107108
I_pre_init : ArrayLike
108-
An array with the same length as ``n_timepoints`` to be used as the initial infections.
109+
An array with the same length as ``n_timepoints`` to be
110+
used as the initial infections.
109111
110112
Returns
111113
-------
112114
ArrayLike
113-
An array of length ``n_timepoints`` with the number of initialized infections at each time point.
115+
An array of length ``n_timepoints`` with the number of
116+
initialized infections at each time point.
114117
"""
118+
I_pre_init = jnp.array(I_pre_init)
115119
if I_pre_init.size != self.n_timepoints:
116120
raise ValueError(
117121
"I_pre_init must have the same size as n_timepoints. "
118122
f"Got I_pre_init of size {I_pre_init.size} "
119123
f"and n_timepoints of size {self.n_timepoints}."
120124
)
121-
return jnp.array(I_pre_init)
125+
return I_pre_init
122126

123127

124128
class InitializeInfectionsExponentialGrowth(InfectionInitializationMethod):
@@ -173,11 +177,12 @@ def initialize_infections(self, I_pre_init: ArrayLike):
173177
ArrayLike
174178
An array of length ``n_timepoints`` with the number of initialized infections at each time point.
175179
"""
180+
I_pre_init = jnp.array(I_pre_init)
176181
if I_pre_init.size != 1:
177182
raise ValueError(
178183
f"I_pre_init must be an array of size 1. Got size {I_pre_init.size}."
179184
)
180-
rate = self.rate()[0].value
185+
rate = jnp.array(self.rate()[0].value)
181186
if rate.size != 1:
182187
raise ValueError(
183188
f"rate must be an array of size 1. Got size {rate.size}."

src/pyrenew/latent/infection_initialization_process.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# -*- coding: utf-8 -*-
22
# numpydoc ignore=GL08
3-
import numpyro
4-
53
from pyrenew.latent.infection_initialization_method import (
64
InfectionInitializationMethod,
75
)
@@ -97,7 +95,6 @@ def sample(self) -> tuple:
9795
infection_initialization = self.infection_init_method(
9896
I_pre_init.value,
9997
)
100-
numpyro.deterministic(self.name, infection_initialization)
10198

10299
return (
103100
SampledValue(

0 commit comments

Comments
 (0)