File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 55import jax .numpy as jnp
66from numpy .typing import ArrayLike
77
8- import pyrenew .arrayutils as au
98import pyrenew .latent .infection_functions as inf
109from pyrenew .metaclass import RandomVariable
1110
@@ -168,23 +167,17 @@ def sample(
168167 )
169168 )
170169
171- if inf_feedback_strength .ndim == Rt .ndim - 1 :
172- inf_feedback_strength = inf_feedback_strength [jnp .newaxis ]
173-
174- # Making sure inf_feedback_strength spans the Rt length
175- if inf_feedback_strength .shape [0 ] == 1 :
176- inf_feedback_strength = au .pad_edges_to_match (
177- x = inf_feedback_strength ,
178- y = Rt ,
179- axis = 0 ,
180- )[0 ]
181- if inf_feedback_strength .shape != Rt .shape :
182- raise ValueError (
183- "Infection feedback strength must be of length 1 "
184- "or the same length as the reproduction number array. "
185- f"Got { inf_feedback_strength .shape } "
186- f"and { Rt .shape } respectively."
170+ try :
171+ inf_feedback_strength = jnp .broadcast_to (
172+ inf_feedback_strength , Rt .shape
187173 )
174+ except Exception as e :
175+ raise ValueError (
176+ "Could not broadcast inf_feedback_strength "
177+ f"(shape { inf_feedback_strength .shape } ) "
178+ "to the shape of Rt"
179+ f"{ Rt .shape } "
180+ ) from e
188181
189182 # Sampling inf feedback pmf
190183 inf_feedback_pmf = self .infection_feedback_pmf (** kwargs )
You can’t perform that action at this time.
0 commit comments