Skip to content

Commit 4898c1d

Browse files
authored
Remove validate() as a mandatory method for RandomVariable and Model (#745)
1 parent 079f9ca commit 4898c1d

23 files changed

Lines changed: 15 additions & 425 deletions

docs/tutorials/random_variables.qmd

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ subclasses to implement:
191191
| Method | Signature | Purpose |
192192
|--------|-----------|---------|
193193
| `sample` | `sample(**kwargs) -> tuple` | Core computation: return a value, draw from a distribution, or perform a calculation |
194-
| `validate` | `validate(**kwargs) -> None` | Check that parameters are well-formed; raise an error if not |
195194

196195
The metaclass also provides:
197196

@@ -512,7 +511,6 @@ and weekday peaks that is characteristic of real hospital admissions data.
512511
### Writing a custom RandomVariable
513512

514513
1. Subclass `RandomVariable` from `pyrenew.metaclass`
515-
2. Implement `validate()` as a `@staticmethod`; call it in `__init__`
516514
3. Implement `sample(**kwargs)` returning a `tuple`
517515
4. Use `numpyro.sample()` for quantities to be estimated
518516
5. Use `numpyro.deterministic()` to record derived quantities in the trace

pyrenew/deterministic/deterministic.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -33,39 +33,10 @@ def __init__(
3333
None
3434
"""
3535
super().__init__(name=name)
36-
self.validate(value)
3736
self.value = value
3837

3938
return None
4039

41-
@staticmethod
42-
def validate(value: ArrayLike) -> None:
43-
"""
44-
Validates input to DeterministicVariable
45-
46-
Parameters
47-
----------
48-
value
49-
An ArrayLike object.
50-
51-
Returns
52-
-------
53-
None
54-
55-
Raises
56-
------
57-
Exception
58-
If the input value object is not an ArrayLike object.
59-
"""
60-
if not isinstance(value, ArrayLike):
61-
raise ValueError(
62-
f"value {value} passed to a DeterministicVariable "
63-
f"is of type {type(value).__name__}, expected "
64-
"an ArrayLike object"
65-
)
66-
67-
return None
68-
6940
def sample(
7041
self,
7142
record: bool = False,

pyrenew/deterministic/deterministicpmf.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -54,22 +54,6 @@ def __init__(
5454

5555
return None
5656

57-
@staticmethod
58-
def validate(value: ArrayLike) -> None:
59-
"""
60-
Validates input to DeterministicPMF
61-
62-
Parameters
63-
----------
64-
value
65-
An ArrayLike object.
66-
67-
Returns
68-
-------
69-
None
70-
"""
71-
return None
72-
7357
def sample(
7458
self,
7559
**kwargs: object,

pyrenew/deterministic/nullrv.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,9 @@ def __init__(self) -> None:
1919
None
2020
"""
2121
RandomVariable.__init__(self, name="null")
22-
self.validate()
2322

2423
return None
2524

26-
@staticmethod
27-
def validate() -> None:
28-
"""
29-
Not used
30-
31-
Returns
32-
-------
33-
None
34-
"""
35-
return None
36-
3725
def sample(
3826
self,
3927
**kwargs: object,
@@ -64,21 +52,9 @@ def __init__(self) -> None:
6452
None
6553
"""
6654
RandomVariable.__init__(self, name="null_observation")
67-
self.validate()
6855

6956
return None
7057

71-
@staticmethod
72-
def validate() -> None:
73-
"""
74-
Not used
75-
76-
Returns
77-
-------
78-
None
79-
"""
80-
return None
81-
8258
def sample(
8359
self,
8460
mu: ArrayLike,

pyrenew/latent/base.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -286,24 +286,6 @@ def get_required_lookback(self) -> int:
286286
"""
287287
return len(self.gen_int_rv())
288288

289-
@abstractmethod
290-
def validate(self) -> None:
291-
"""
292-
Validate latent process parameters.
293-
294-
Subclasses must implement this method to validate all parameters specific
295-
to their implementation (e.g., temporal process parameters, I0 parameters).
296-
297-
Common validation (n_initialization_points, gen_int_rv) is performed in
298-
__init__. Population structure validation is performed at sample time.
299-
300-
Raises
301-
------
302-
ValueError
303-
If any parameters fail validation
304-
"""
305-
pass # pragma: no cover
306-
307289
@abstractmethod
308290
def sample(
309291
self,

pyrenew/latent/hierarchical_priors.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,6 @@ def __init__(
5151
super().__init__(name=name)
5252
self.sd_rv = sd_rv
5353

54-
def validate(self) -> None:
55-
"""Validate the random variable (no-op for this class)."""
56-
pass
57-
5854
def sample(self, n_groups: int, **kwargs: object) -> ArrayLike:
5955
"""
6056
Sample group-level effects.
@@ -136,10 +132,6 @@ def __init__(
136132
self.sd_concentration_rv = sd_concentration_rv
137133
self.sd_min = sd_min
138134

139-
def validate(self) -> None:
140-
"""Validate the random variable (no-op for this class)."""
141-
pass
142-
143135
def sample(self, n_groups: int, **kwargs: object) -> ArrayLike:
144136
"""
145137
Sample group-level standard deviations.
@@ -220,10 +212,6 @@ def __init__(
220212
self.sd_rv = sd_rv
221213
self.df_rv = df_rv
222214

223-
def validate(self) -> None:
224-
"""Validate the random variable (no-op for this class)."""
225-
pass
226-
227215
def sample(self, n_groups: int, **kwargs: object) -> ArrayLike:
228216
"""
229217
Sample group-level modes.

pyrenew/latent/infection_initialization_process.py

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pyrenew.latent.infection_initialization_method import (
77
InfectionInitializationMethod,
88
)
9-
from pyrenew.metaclass import RandomVariable, _assert_type
9+
from pyrenew.metaclass import RandomVariable
1010

1111

1212
class InfectionInitializationProcess(RandomVariable):
@@ -33,37 +33,10 @@ def __init__(
3333
-------
3434
None
3535
"""
36-
InfectionInitializationProcess.validate(I_pre_init_rv, infection_init_method)
37-
3836
super().__init__(name=name)
3937
self.I_pre_init_rv = I_pre_init_rv
4038
self.infection_init_method = infection_init_method
4139

42-
@staticmethod
43-
def validate(
44-
I_pre_init_rv: RandomVariable,
45-
infection_init_method: InfectionInitializationMethod,
46-
) -> None:
47-
"""Validate the input arguments to the InfectionInitializationProcess class constructor
48-
49-
Parameters
50-
----------
51-
I_pre_init_rv
52-
A random variable representing the number of infections that occur at some time before the renewal process begins.
53-
infection_init_method
54-
An method to generate the initial infections.
55-
56-
Returns
57-
-------
58-
None
59-
"""
60-
_assert_type("I_pre_init_rv", I_pre_init_rv, RandomVariable)
61-
_assert_type(
62-
"infection_init_method",
63-
infection_init_method,
64-
InfectionInitializationMethod,
65-
)
66-
6740
def sample(self) -> ArrayLike:
6841
"""Sample the Infection Initialization Process.
6942

pyrenew/latent/infections.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,6 @@ def __init__(self, name: str) -> None:
6060
"""
6161
super().__init__(name=name)
6262

63-
@staticmethod
64-
def validate() -> None: # numpydoc ignore=GL08
65-
return None
66-
6763
def sample(
6864
self,
6965
Rt: ArrayLike,

pyrenew/metaclass.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,35 +10,6 @@
1010
from numpyro.infer import MCMC, NUTS, Predictive, init_to_sample
1111

1212

13-
def _assert_type(arg_name: str, value: object, expected_type: type) -> None:
14-
"""
15-
Matches TypeError arising during validation
16-
17-
Parameters
18-
----------
19-
arg_name
20-
Name of the argument
21-
value
22-
The object to be validated
23-
expected_type
24-
The expected object type
25-
26-
Raises
27-
------
28-
TypeError
29-
If `value` is not an instance of `expected_type`.
30-
31-
Returns
32-
-------
33-
None
34-
"""
35-
36-
if not isinstance(value, expected_type):
37-
raise TypeError(
38-
f"{arg_name} must be an instance of {expected_type}. Got {type(value)}"
39-
)
40-
41-
4213
class RandomVariable(metaclass=ABCMeta):
4314
"""
4415
Abstract base class for latent and observed random variables.
@@ -93,14 +64,6 @@ def sample(
9364
"""
9465
pass
9566

96-
@staticmethod
97-
@abstractmethod
98-
def validate(**kwargs: object) -> None:
99-
"""
100-
Validation of kwargs to be implemented in subclasses.
101-
"""
102-
pass
103-
10467
def __call__(self, **kwargs: object) -> tuple:
10568
"""
10669
Alias for `sample`.
@@ -119,11 +82,6 @@ class Model(metaclass=ABCMeta):
11982
def __init__(self, **kwargs: object) -> None: # numpydoc ignore=GL08
12083
pass
12184

122-
@staticmethod
123-
@abstractmethod
124-
def validate() -> None: # numpydoc ignore=GL08
125-
pass
126-
12785
@abstractmethod
12886
def sample(
12987
self,

pyrenew/observation/base.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ class BaseObservationProcess(RandomVariable):
3636
3737
Subclasses must implement:
3838
39-
- ``validate()``: Validate parameters (call ``_validate_pmf()`` for PMFs)
4039
- ``lookback_days()``: Return PMF length for initialization
4140
- ``infection_resolution()``: Return ``"aggregate"`` or ``"subpop"``
4241
- ``_predicted_obs()``: Transform infections to predicted values
@@ -79,22 +78,6 @@ def __init__(self, name: str, temporal_pmf_rv: RandomVariable) -> None:
7978
super().__init__(name=name)
8079
self.temporal_pmf_rv = temporal_pmf_rv
8180

82-
@abstractmethod
83-
def validate(self) -> None:
84-
"""
85-
Validate observation process parameters.
86-
87-
Subclasses must implement this method to validate all parameters.
88-
Typically this involves calling ``_validate_pmf()`` for the PMF
89-
and adding any additional parameter-specific validation.
90-
91-
Raises
92-
------
93-
ValueError
94-
If any parameters fail validation.
95-
"""
96-
pass # pragma: no cover
97-
9881
@abstractmethod
9982
def lookback_days(self) -> int:
10083
"""

0 commit comments

Comments
 (0)