Skip to content

Commit 073aea2

Browse files
authored
Move and rename VectorizedVariable (#740)
1 parent 4fc516c commit 073aea2

10 files changed

Lines changed: 114 additions & 111 deletions

docs/tutorials/observation_processes_measurements.qmd

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@ from _tutorial_theme import theme_tutorial
2424
from pyrenew.observation import (
2525
Measurements,
2626
HierarchicalNormalNoise,
27-
VectorizedRV,
2827
)
29-
from pyrenew.randomvariable import DistributionalVariable
28+
from pyrenew.randomvariable import DistributionalVariable, VectorizedVariable
3029
from pyrenew.deterministic import DeterministicVariable, DeterministicPMF
3130
```
3231

@@ -83,18 +82,18 @@ Measurement data typically exhibits **sensor-level variability**: different inst
8382
observed ~ Normal(predicted + sensor_mode[sensor], sensor_sd[sensor])
8483
```
8584

86-
The sensor-level RVs must implement `sample(n_groups=...)`. Use `VectorizedRV` to wrap simple distributions:
85+
The sensor-level RVs must implement `sample(n_groups=...)`. Use `VectorizedVariable` to wrap simple distributions:
8786

8887
```{python}
8988
# | label: noise-model-general
9089
# Sensor modes: zero-centered, allowing positive or negative bias
91-
sensor_mode_rv = VectorizedRV(
90+
sensor_mode_rv = VectorizedVariable(
9291
"vec_sensor_mode",
9392
DistributionalVariable("sensor_mode", dist.Normal(0, 0.5)),
9493
)
9594
9695
# Sensor SDs: must be positive, truncated normal is a common choice
97-
sensor_sd_rv = VectorizedRV(
96+
sensor_sd_rv = VectorizedVariable(
9897
"vec_sensor_sd",
9998
DistributionalVariable(
10099
"sensor_sd", dist.TruncatedNormal(loc=0.3, scale=0.15, low=0.05)
@@ -355,13 +354,13 @@ For wastewater, a "sensor" is a WWTP/lab pair—the combination of treatment pla
355354
```{python}
356355
# | label: ww-noise-model
357356
# Sensor-level mode: systematic differences between WWTP/lab pairs
358-
ww_sensor_mode_rv = VectorizedRV(
357+
ww_sensor_mode_rv = VectorizedVariable(
359358
"vec_ww_sensor_mode",
360359
DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.5)),
361360
)
362361
363362
# Sensor-level SD: measurement variability within each WWTP/lab pair
364-
ww_sensor_sd_rv = VectorizedRV(
363+
ww_sensor_sd_rv = VectorizedVariable(
365364
"vec_ww_sensor_sd",
366365
DistributionalVariable(
367366
"ww_sensor_sd", dist.TruncatedNormal(loc=0.3, scale=0.15, low=0.10)

pyrenew/observation/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
MeasurementNoise,
2929
NegativeBinomialNoise,
3030
PoissonNoise,
31-
VectorizedRV,
3231
)
3332
from pyrenew.observation.types import ObservationSample
3433

@@ -44,7 +43,6 @@
4443
"NegativeBinomialNoise",
4544
"MeasurementNoise",
4645
"HierarchicalNormalNoise",
47-
"VectorizedRV",
4846
# Observation processes
4947
"Counts",
5048
"CountsBySubpop",

pyrenew/observation/noise.py

Lines changed: 2 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# numpydoc ignore=GL08
21
"""
32
Noise models for observation processes.
43
@@ -16,10 +15,6 @@
1615
- ``HierarchicalNormalNoise``: Normal noise with hierarchical sensor effects.
1716
Takes ``sensor_mode_rv`` and ``sensor_sd_rv`` for sensor-level
1817
bias and variability.
19-
20-
**Utilities**
21-
22-
- ``VectorizedRV``: Wrapper that adds ``n_groups`` support to simple RVs.
2318
"""
2419

2520
from __future__ import annotations
@@ -36,60 +31,6 @@
3631
_EPSILON = 1e-10
3732

3833

39-
class VectorizedRV(RandomVariable):
40-
"""
41-
Wrapper that adds n_groups support to simple RandomVariables.
42-
43-
Uses numpyro.plate to vectorize sampling, enabling simple RVs
44-
to work with noise models expecting the group-level interface.
45-
46-
Parameters
47-
----------
48-
name
49-
A name for this random variable.
50-
The numpyro plate is named ``f"{name}_plate"``.
51-
rv
52-
The underlying RandomVariable to wrap.
53-
"""
54-
55-
def __init__(self, name: str, rv: RandomVariable) -> None:
56-
"""
57-
Initialize VectorizedRV wrapper.
58-
59-
Parameters
60-
----------
61-
name
62-
A name for this random variable.
63-
The numpyro plate is named ``f"{name}_plate"``.
64-
rv
65-
The underlying RandomVariable to wrap.
66-
"""
67-
super().__init__(name=name)
68-
self.rv = rv
69-
self.plate_name = f"{name}_plate"
70-
71-
def validate(self) -> None: # pragma: no cover
72-
"""Validate the underlying RV."""
73-
self.rv.validate()
74-
75-
def sample(self, n_groups: int, **kwargs: object) -> ArrayLike:
76-
"""
77-
Sample n_groups values using numpyro.plate.
78-
79-
Parameters
80-
----------
81-
n_groups
82-
Number of group-level values to sample.
83-
84-
Returns
85-
-------
86-
ArrayLike
87-
Array of shape (n_groups,).
88-
"""
89-
with numpyro.plate(self.plate_name, n_groups):
90-
return self.rv(**kwargs)
91-
92-
9334
class CountNoise(ABC):
9435
"""
9536
Abstract base for count observation noise models.
@@ -356,7 +297,8 @@ class HierarchicalNormalNoise(MeasurementNoise):
356297
357298
Notes
358299
-----
359-
Use ``VectorizedRV`` to wrap simple RVs that lack this interface.
300+
Use [`VectorizedVariable`][pyrenew.randomvariable.VectorizedVariable]
301+
to wrap simple RVs that lack this interface.
360302
"""
361303

362304
def __init__(

pyrenew/randomvariable/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
StaticDistributionalVariable,
77
)
88
from pyrenew.randomvariable.transformedvariable import TransformedVariable
9+
from pyrenew.randomvariable.vectorizedvariable import VectorizedVariable
910

1011
__all__ = [
1112
"DistributionalVariable",
1213
"StaticDistributionalVariable",
1314
"DynamicDistributionalVariable",
1415
"TransformedVariable",
16+
"VectorizedVariable",
1517
]
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""
2+
Vectorization wrapper for simple RandomVariables
3+
"""
4+
5+
import numpyro
6+
from jax.typing import ArrayLike
7+
8+
from pyrenew.metaclass import RandomVariable
9+
10+
11+
class VectorizedVariable(RandomVariable):
12+
"""
13+
Wrapper that adds n_groups support to simple RandomVariables.
14+
15+
Uses numpyro.plate to vectorize sampling, enabling simple RVs
16+
to work with noise models expecting the group-level interface.
17+
"""
18+
19+
def __init__(self, name: str, rv: RandomVariable) -> None:
20+
"""
21+
Initialize VectorizedVariable wrapper.
22+
23+
Parameters
24+
----------
25+
name
26+
A name for this random variable.
27+
The numpyro plate used to vectorize will
28+
have this name with the suffix `_plate"`.
29+
rv
30+
The underlying RandomVariable to wrap.
31+
"""
32+
super().__init__(name=name)
33+
self.rv = rv
34+
self.plate_name = f"{name}_plate"
35+
36+
def validate(self) -> None: # pragma: no cover
37+
"""Validate the underlying RV."""
38+
self.rv.validate()
39+
40+
def sample(self, n_groups: int, **kwargs: object) -> ArrayLike:
41+
"""
42+
Sample n_groups values using numpyro.plate.
43+
44+
Parameters
45+
----------
46+
n_groups
47+
Number of group-level values to sample.
48+
49+
Returns
50+
-------
51+
ArrayLike
52+
Array of shape (n_groups,).
53+
"""
54+
with numpyro.plate(self.plate_name, n_groups):
55+
return self.rv(**kwargs)

test/conftest.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515
Counts,
1616
HierarchicalNormalNoise,
1717
NegativeBinomialNoise,
18-
VectorizedRV,
1918
)
20-
from pyrenew.randomvariable import DistributionalVariable
19+
from pyrenew.randomvariable import DistributionalVariable, VectorizedVariable
2120

2221
# =============================================================================
2322
# PMF Fixtures
@@ -103,18 +102,18 @@ def gen_int_rv():
103102
@pytest.fixture
104103
def hierarchical_normal_noise():
105104
"""
106-
Standard HierarchicalNormalNoise with VectorizedRV wrappers.
105+
Standard HierarchicalNormalNoise with VectorizedVariable wrappers.
107106
108107
Returns
109108
-------
110109
HierarchicalNormalNoise
111110
Noise model for continuous measurements.
112111
"""
113-
sensor_mode_rv = VectorizedRV(
112+
sensor_mode_rv = VectorizedVariable(
114113
name="sensor_mode_rv",
115114
rv=DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.5)),
116115
)
117-
sensor_sd_rv = VectorizedRV(
116+
sensor_sd_rv = VectorizedVariable(
118117
name="sensor_sd_rv",
119118
rv=DistributionalVariable(
120119
"ww_sensor_sd", dist.TruncatedNormal(0.3, 0.15, low=0.10)
@@ -133,11 +132,11 @@ def hierarchical_normal_noise_tight():
133132
HierarchicalNormalNoise
134133
Noise model with very small variance.
135134
"""
136-
sensor_mode_rv = VectorizedRV(
135+
sensor_mode_rv = VectorizedVariable(
137136
name="sensor_mode_rv",
138137
rv=DistributionalVariable("ww_sensor_mode", dist.Normal(0, 0.01)),
139138
)
140-
sensor_sd_rv = VectorizedRV(
139+
sensor_sd_rv = VectorizedVariable(
141140
name="sensor_sd_rv",
142141
rv=DistributionalVariable(
143142
"ww_sensor_sd", dist.TruncatedNormal(0.01, 0.005, low=0.001)

test/test_interface_coverage.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,16 @@
3838
NegativeBinomialNoise,
3939
NegativeBinomialObservation,
4040
PoissonNoise,
41-
VectorizedRV,
4241
)
4342
from pyrenew.process import ARProcess, DifferencedProcess
4443
from pyrenew.process.iidrandomsequence import IIDRandomSequence, StandardNormalSequence
4544
from pyrenew.process.randomwalk import RandomWalk as ProcessRandomWalk
4645
from pyrenew.process.randomwalk import StandardNormalRandomWalk
47-
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
46+
from pyrenew.randomvariable import (
47+
DistributionalVariable,
48+
TransformedVariable,
49+
VectorizedVariable,
50+
)
4851
from test.test_helpers import ConcreteMeasurements
4952

5053
# =============================================================================
@@ -92,11 +95,11 @@ def _make_measurements():
9295
-------
9396
instantiated object
9497
"""
95-
sensor_mode_rv = VectorizedRV(
98+
sensor_mode_rv = VectorizedVariable(
9699
name="sensor_mode_rv",
97100
rv=DistributionalVariable("mode", dist.Normal(0, 0.5)),
98101
)
99-
sensor_sd_rv = VectorizedRV(
102+
sensor_sd_rv = VectorizedVariable(
100103
name="sensor_sd_rv",
101104
rv=DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.1)),
102105
)
@@ -115,11 +118,11 @@ def _make_hierarchical_normal_noise():
115118
-------
116119
instantiated object
117120
"""
118-
sensor_mode_rv = VectorizedRV(
121+
sensor_mode_rv = VectorizedVariable(
119122
name="sensor_mode_rv",
120123
rv=DistributionalVariable("mode", dist.Normal(0, 0.5)),
121124
)
122-
sensor_sd_rv = VectorizedRV(
125+
sensor_sd_rv = VectorizedVariable(
123126
name="sensor_sd_rv",
124127
rv=DistributionalVariable("sd", dist.TruncatedNormal(0.3, 0.15, low=0.1)),
125128
)
@@ -455,12 +458,12 @@ def test_random_variable_rejects_invalid_name(bad_name):
455458
pytest.param(_make_counts_by_subpop(), "test_subpop", id="CountsBySubpop"),
456459
pytest.param(_make_measurements(), "test_ww", id="ConcreteMeasurements"),
457460
pytest.param(
458-
VectorizedRV(
461+
VectorizedVariable(
459462
name="test_vec",
460463
rv=DistributionalVariable("inner", dist.Normal(0, 1)),
461464
),
462465
"test_vec",
463-
id="VectorizedRV",
466+
id="VectorizedVariable",
464467
),
465468
],
466469
)

test/test_observation_measurements.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,30 +11,11 @@
1111
import pytest
1212

1313
from pyrenew.deterministic import DeterministicPMF
14-
from pyrenew.observation import (
15-
HierarchicalNormalNoise,
16-
VectorizedRV,
17-
)
18-
from pyrenew.randomvariable import DistributionalVariable
14+
from pyrenew.observation import HierarchicalNormalNoise
15+
from pyrenew.randomvariable import DistributionalVariable, VectorizedVariable
1916
from test.test_helpers import ConcreteMeasurements
2017

2118

22-
class TestVectorizedRV:
23-
"""Test VectorizedRV wrapper class."""
24-
25-
def test_init_and_sample(self):
26-
"""Test VectorizedRV initialization and sampling."""
27-
rv = DistributionalVariable("test", dist.Normal(0, 1.0))
28-
vectorized = VectorizedRV(name="test_vectorized", rv=rv)
29-
30-
with numpyro.handlers.seed(rng_seed=42):
31-
samples = vectorized.sample(n_groups=5)
32-
33-
assert samples.shape == (5,)
34-
# Verify samples are actually different (not degenerate)
35-
assert jnp.std(samples) > 0
36-
37-
3819
class TestHierarchicalNormalNoise:
3920
"""Test HierarchicalNormalNoise model."""
4021

@@ -235,11 +216,11 @@ def test_sensor_bias_differences(self):
235216
shedding_pmf = jnp.array([1.0])
236217

237218
# Use wide priors to ensure sensors get distinguishable biases
238-
sensor_mode_rv = VectorizedRV(
219+
sensor_mode_rv = VectorizedVariable(
239220
name="sensor_mode_rv",
240221
rv=DistributionalVariable("mode", dist.Normal(0, 2.0)),
241222
)
242-
sensor_sd_rv = VectorizedRV(
223+
sensor_sd_rv = VectorizedVariable(
243224
name="sensor_sd_rv",
244225
rv=DistributionalVariable("sd", dist.TruncatedNormal(0.1, 0.05, low=0.01)),
245226
)

0 commit comments

Comments
 (0)