Skip to content

Commit 32e80fd

Browse files
authored
refactor infectionswithfeedback.py to allow shared infection feedback strength across sites (#470)
1 parent 9a996a3 commit 32e80fd

1 file changed

Lines changed: 10 additions & 17 deletions

File tree

pyrenew/latent/infectionswithfeedback.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import jax.numpy as jnp
66
from numpy.typing import ArrayLike
77

8-
import pyrenew.arrayutils as au
98
import pyrenew.latent.infection_functions as inf
109
from pyrenew.metaclass import RandomVariable
1110

@@ -168,23 +167,17 @@ def sample(
168167
)
169168
)
170169

171-
if inf_feedback_strength.ndim == Rt.ndim - 1:
172-
inf_feedback_strength = inf_feedback_strength[jnp.newaxis]
173-
174-
# Making sure inf_feedback_strength spans the Rt length
175-
if inf_feedback_strength.shape[0] == 1:
176-
inf_feedback_strength = au.pad_edges_to_match(
177-
x=inf_feedback_strength,
178-
y=Rt,
179-
axis=0,
180-
)[0]
181-
if inf_feedback_strength.shape != Rt.shape:
182-
raise ValueError(
183-
"Infection feedback strength must be of length 1 "
184-
"or the same length as the reproduction number array. "
185-
f"Got {inf_feedback_strength.shape} "
186-
f"and {Rt.shape} respectively."
170+
try:
171+
inf_feedback_strength = jnp.broadcast_to(
172+
inf_feedback_strength, Rt.shape
187173
)
174+
except Exception as e:
175+
raise ValueError(
176+
"Could not broadcast inf_feedback_strength "
177+
f"(shape {inf_feedback_strength.shape}) "
178+
"to the shape of Rt"
179+
f"{Rt.shape}"
180+
) from e
188181

189182
# Sampling inf feedback pmf
190183
inf_feedback_pmf = self.infection_feedback_pmf(**kwargs)

0 commit comments

Comments
 (0)