3636
3737if TYPE_CHECKING :
3838 from pysatl_core .distributions .support import Support
39- from pysatl_core .types import Number , NumberParameter , NumericArray
39+ from pysatl_core .types import Number , NumericArray
4040
4141 type ParametrizedFunction = Callable [[Parametrization , Any ], Any ]
4242 type SupportArg = Callable [[Parametrization ], Support | None ] | None
@@ -52,10 +52,10 @@ class ExponentialFamilyParametrization(Parametrization):
5252 f(x|θ) = h(x) * exp(θᵀ T(x) - A(θ))
5353
5454 Attributes:
55- theta (NumberParameter ): Natural parameter vector (can be a scalar or array)
55+ theta (NumericArray ): Natural parameter vector (can be a scalar or array)
5656 """
5757
58- theta : NumberParameter
58+ theta : NumericArray
5959
6060 def transform_to_base_parametrization (self ) -> ExponentialFamilyParametrization :
6161 """Return the base parametrization (identity transform for canonical form)."""
@@ -120,9 +120,9 @@ class ContinuousExponentialClassFamily(ParametricFamily):
120120 def __init__ (
121121 self ,
122122 * ,
123- log_partition : Callable [[NumberParameter ], NumberParameter ],
124- sufficient_statistics : Callable [[NumberParameter ], NumberParameter ],
125- normalization_constant : Callable [[NumberParameter ], NumberParameter ],
123+ log_partition : Callable [[NumericArray ], NumericArray ],
124+ sufficient_statistics : Callable [[NumericArray ], NumericArray ],
125+ normalization_constant : Callable [[NumericArray ], Number ],
126126 support : SupportByPredicate ,
127127 parameter_space : SupportByPredicate ,
128128 sufficient_statistics_values : SupportByPredicate ,
@@ -185,13 +185,13 @@ def log_density(self) -> ParametrizedFunction:
185185 and a point `x`, and returns log f(x|θ). Returns -inf for x outside the support.
186186
187187 Returns:
188- Callable[[Parametrization, NumberParameter ], Number]
188+ Callable[[Parametrization, NumericArray ], Number]
189189 """
190190
191- def log_density_func (parametrization : Parametrization , x : NumberParameter ) -> Number :
191+ def log_density_func (parametrization : Parametrization , x : NumericArray ) -> Number :
192192 parametrization = cast (ExponentialFamilyParametrization , parametrization )
193193 parametrization = parametrization .transform_to_base_parametrization ()
194- if x not in self ._support :
194+ if np . array ([ x ]) not in self ._support :
195195 return - np .inf
196196
197197 theta = parametrization .theta
@@ -211,7 +211,7 @@ def density(self) -> ParametrizedFunction:
211211 Density function (exponentiated log‑density).
212212
213213 Returns:
214- Callable[[Parametrization, NumberParameter ], Number]
214+ Callable[[Parametrization, NumericArray ], Number]
215215 """
216216 return lambda parametrization , x : np .exp (self .log_density (parametrization , x ))
217217
@@ -230,8 +230,8 @@ def conjugate_prior_family(self) -> ContinuousExponentialClassFamily:
230230 """
231231
232232 def conjugate_sufficient (
233- theta : NumberParameter ,
234- ) -> NumberParameter :
233+ theta : NumericArray ,
234+ ) -> NumericArray :
235235 if not hasattr (theta , "__len__" ):
236236 theta = np .array ([theta ])
237237
@@ -240,9 +240,9 @@ def conjugate_sufficient(
240240 return np .append (theta , self ._log_partition (theta ))
241241
242242 def conjugate_log_partition (
243- parametrization : NumberParameter ,
244- ) -> NumberParameter :
245- def pdf (theta : NumberParameter ) -> NumberParameter :
243+ parametrization : NumericArray ,
244+ ) -> NumericArray :
245+ def pdf (theta : NumericArray ) -> Number :
246246 if not hasattr (theta , "__len__" ):
247247 theta = np .array ([theta ])
248248 return cast (
@@ -259,23 +259,25 @@ def pdf(theta: NumberParameter) -> NumberParameter:
259259 lambda x : pdf (x ) if x in self ._parameter_space else 0 , # type: ignore[arg-type]
260260 [(float ("-inf" ), float ("+inf" ))],
261261 )[0 ]
262- return cast (np .float64 , - np .log (all_value ))
262+ return np . array ([ cast (np .float64 , - np .log (all_value ))] )
263263
264264 def conjugate_sufficient_accepts (
265265 theta : NumericArray ,
266266 ) -> bool :
267267 xi = theta [:- 1 ]
268268 nu = theta [- 1 ]
269269
270- return xi in self ._sufficient_statistics_values and nu in ContinuousSupport (0 , np .inf )
270+ return xi in self ._sufficient_statistics_values and np .array ([nu ]) in ContinuousSupport (
271+ 0 , np .inf
272+ )
271273
272274 return ContinuousExponentialClassFamily (
273275 log_partition = conjugate_log_partition ,
274276 sufficient_statistics = conjugate_sufficient ,
275277 normalization_constant = lambda _ : 1 ,
276278 support = self ._parameter_space ,
277- sufficient_statistics_values = self ._parameter_space , # TODO: write convex hull for this
278- parameter_space = SupportByPredicate (predicate = conjugate_sufficient_accepts ), # type: ignore[arg-type]
279+ sufficient_statistics_values = self ._parameter_space ,
280+ parameter_space = SupportByPredicate (predicate = conjugate_sufficient_accepts ),
279281 name = self .name ,
280282 distr_type = self ._distr_type ,
281283 distr_parametrizations = self .parametrization_names ,
@@ -284,7 +286,7 @@ def conjugate_sufficient_accepts(
284286
285287 def transform (
286288 self ,
287- transform_function : Callable [[NumberParameter ], NumberParameter ],
289+ transform_function : Callable [[NumericArray ], NumericArray ],
288290 ) -> ContinuousExponentialClassFamily :
289291 """
290292 Transform the random variable by a monotonic, differentiable function.
@@ -301,20 +303,20 @@ def transform(
301303 ContinuousExponentialClassFamily: A new family for the transformed variable.
302304 """
303305
304- def calculate_jacobian (x : NumberParameter ) -> NumberParameter :
306+ def calculate_jacobian (x : NumericArray ) -> NumericArray :
305307 if not isinstance (x , Iterable ):
306308 x = np .array ([x ])
307309
308310 return np .abs (det (jacobian (transform_function , x ).df ))
309311
310- def new_support (x : NumberParameter ) -> bool :
312+ def new_support (x : NumericArray ) -> bool :
311313 return transform_function (x ) in self ._support
312314
313- def new_sufficient (x : NumberParameter ) -> NumberParameter :
315+ def new_sufficient (x : NumericArray ) -> NumericArray :
314316 return self ._sufficient (transform_function (x ))
315317
316- def new_normalization (x : NumberParameter ) -> NumberParameter :
317- return self ._normalization (x ) * calculate_jacobian (x )
318+ def new_normalization (x : NumericArray ) -> Number :
319+ return cast ( np . float64 , self ._normalization (x ) * calculate_jacobian (x ) )
318320
319321 return ContinuousExponentialClassFamily (
320322 log_partition = self ._log_partition ,
@@ -339,7 +341,11 @@ def mean_func(parametrization: Parametrization, x: Any) -> Any:
339341 if hasattr (x , "__len__" ):
340342 dimension_size = len (x )
341343 return nquad (
342- lambda x : np .dot (x , self .density (parametrization , x )) if x in self ._support else 0 ,
344+ lambda x : (
345+ np .dot (x , self .density (parametrization , x ))
346+ if np .array ([x ]) in self ._support
347+ else 0
348+ ),
343349 [(float ("-inf" ), float ("inf" ))] * dimension_size ,
344350 )[0 ]
345351
@@ -355,7 +361,9 @@ def func(parametrization: Parametrization, x: Any) -> Any:
355361 if hasattr (x , "__len__" ):
356362 dimension_size = len (x )
357363 return nquad (
358- lambda x : x ** 2 * self .density (parametrization , x ) if x in self ._support else 0 ,
364+ lambda x : (
365+ x ** 2 * self .density (parametrization , x ) if np .array ([x ]) in self ._support else 0
366+ ),
359367 [(float ("-inf" ), float ("inf" ))] * dimension_size ,
360368 )[0 ]
361369
@@ -394,7 +402,7 @@ def posterior_hyperparameters(
394402 posterior_effective_sample_size = parametrizaiton .effective_sample_size
395403 if hasattr (sample , "__iter__" ) and not isinstance (sample , str ):
396404 posterior_effective_suff_stat_value += np .sum (
397- [self ._sufficient (x ) for x in sample ], # type: ignore[arg-type]
405+ [self ._sufficient (x ) for x in sample ],
398406 axis = 0 ,
399407 )
400408 posterior_effective_sample_size += len (sample )
@@ -424,13 +432,13 @@ def posterior_predictive(self) -> ParametricFamily:
424432
425433 def conjugate_log_partition (
426434 parametrization : ExponentialConjugateHyperparameters ,
427- ) -> NumberParameter :
435+ ) -> NumericArray :
428436 conjugate_value = self .conjugate_prior_family ._log_partition (
429437 parametrization .transform_to_base_parametrization ().theta
430438 )
431439 return np .exp (conjugate_value )
432440
433- def posterior_density (parametrization : Parametrization , x : NumberParameter ) -> Number :
441+ def posterior_density (parametrization : Parametrization , x : NumericArray ) -> Number :
434442 parametrization = cast (ExponentialConjugateHyperparameters , parametrization )
435443 return cast (
436444 np .float32 ,
0 commit comments