1+ """
2+ Exponential family distributions in continuous spaces.
3+
4+ This module implements the continuous exponential family of probability distributions,
5+ their conjugate priors, posterior inference, and posterior predictive distributions.
6+ """
7+
18from __future__ import annotations
29
310__author__ = "Vinogradov Ilya"
411__copyright__ = "Copyright (c) 2025 PySATL project"
512__license__ = "SPDX-License-Identifier: MIT"
613
7-
814from collections .abc import Callable , Iterable
915from dataclasses import dataclass
1016from typing import TYPE_CHECKING , Any , cast
3945@dataclass
4046class ExponentialFamilyParametrization (Parametrization ):
4147 """
42- Standard parametrization of Exponential Family.
48+ Standard parametrization of an exponential family distribution.
49+
50+ This parametrization uses the natural (canonical) parameter vector `theta`
51+ The density is expressed as:
52+ f(x|θ) = h(x) * exp(θᵀ T(x) - A(θ))
53+
54+ Attributes:
55+ theta (NumberParameter): Natural parameter vector (can be a scalar or array)
4356 """
4457
4558 theta : NumberParameter
4659
4760 def transform_to_base_parametrization (self ) -> ExponentialFamilyParametrization :
61+ """Return the base parametrization (identity transform for canonical form)."""
4862 return self
4963
5064
5165@dataclass
5266class ExponentialConjugateHyperparameters (Parametrization ):
67+ """
68+ Hyperparameters for the conjugate prior of an exponential family
69+
70+ For a prior of the form:
71+ p(θ) ∝ exp(ν₀ᵀ T(θ) + n₀ A(θ))
72+ the hyperparameters are:
73+ effective_suff_stat_value = ν₀
74+ effective_sample_size = n₀
75+
76+ Attributes:
77+ effective_suff_stat_value (NumericArray): Pseudo‑sufficient statistic ν₀
78+ effective_sample_size (Number): Pseudo‑sample size n₀ (a non‑negative scalar)
79+ """
80+
5381 effective_suff_stat_value : NumericArray
5482 effective_sample_size : Number
5583
5684 def transform_to_base_parametrization (self ) -> ExponentialFamilyParametrization :
85+ """
86+ Convert hyperparameters to a canonical parametrization.
87+
88+ The resulting parameter vector is [ν₀, n₀] concatenated.
89+ """
5790 return ExponentialFamilyParametrization (
5891 np .append (self .effective_suff_stat_value , self .effective_sample_size )
5992 )
6093
6194
6295class ContinuousExponentialClassFamily (ParametricFamily ):
6396 """
64- Representation of exponential class with density = h(x) * exp(<n(t), T(x)> + A(t)),
65- where canonical parametrization is that, when n = t
66-
67- Usage of this class:
68- - you can use method transform_to_another to replace x to smth else, for example, into
97+ Representation of a continuous exponential family distribution.
98+
99+ The density is given by:
100+ f(x|θ) = h(x) * exp(θᵀ T(x) - A(θ))
101+
102+ where:
103+ - θ is the natural parameter,
104+ - T(x) is the sufficient statistic vector,
105+ - h(x) is the base measure (the `normalization_constant`),
106+ - A(θ) is the log‑partition function.
107+
108+ This class supports:
109+ - Canonical parametrization (θ) via `ExponentialFamilyParametrization`.
110+ - Conjugate prior families.
111+ - Posterior updates and posterior predictive distributions.
112+ - Transformation of the random variable (change of variable with Jacobian).
113+
114+ The user must supply functions for the log‑partition `log_partition`,
115+ sufficient statistics `sufficient_statistics`,
116+ base measure `normalization_constant`, as well as the support of the distribution,
117+ the natural parameter space and the range of the sufficient statistic.
69118 """
70119
71120 def __init__ (
@@ -83,6 +132,22 @@ def __init__(
83132 support_by_parametrization : SupportArg = None ,
84133 base_score : Callable [[Parametrization , NumericArray ], NumericArray ] | None = None ,
85134 ):
135+ """
136+ Initialize a continuous exponential family distribution.
137+
138+ Args:
139+ log_partition: Function A(θ) – the log‑partition function.
140+ sufficient_statistics: Function T(x) – the sufficient statistic vector.
141+ normalization_constant: Function h(x) – the base measure.
142+ support: Predicate defining the support of the distribution.
143+ parameter_space: Predicate defining the natural parameter space.
144+ sufficient_statistics_values: Predicate defining the range of T(x).
145+ name: Name of the family.
146+ distr_type: Type of distribution or a callable returning it.
147+ distr_parametrizations: List of parametrization names this family supports.
148+ support_by_parametrization: Callable that returns the support given a parametrization.
149+ base_score: Optional base score function.
150+ """
86151 self ._sufficient = sufficient_statistics
87152 self ._log_partition = log_partition
88153 self ._normalization = normalization_constant
@@ -113,6 +178,16 @@ def __init__(
113178
114179 @property
115180 def log_density (self ) -> ParametrizedFunction :
181+ """
182+ Log‑density function for the exponential family.
183+
184+ The function takes a parametrization (must be `ExponentialFamilyParametrization`)
185+ and a point `x`, and returns log f(x|θ). Returns -inf for x outside the support.
186+
187+ Returns:
188+ Callable[[Parametrization, NumberParameter], Number]
189+ """
190+
116191 def log_density_func (parametrization : Parametrization , x : NumberParameter ) -> Number :
117192 parametrization = cast (ExponentialFamilyParametrization , parametrization )
118193 parametrization = parametrization .transform_to_base_parametrization ()
@@ -132,10 +207,28 @@ def log_density_func(parametrization: Parametrization, x: NumberParameter) -> Nu
132207
133208 @property
134209 def density (self ) -> ParametrizedFunction :
210+ """
211+ Density function (exponentiated log‑density).
212+
213+ Returns:
214+ Callable[[Parametrization, NumberParameter], Number]
215+ """
135216 return lambda parametrization , x : np .exp (self .log_density (parametrization , x ))
136217
137218 @property
138219 def conjugate_prior_family (self ) -> ContinuousExponentialClassFamily :
220+ """
221+ Build the conjugate prior family for this exponential family.
222+
223+ The conjugate prior is an exponential family in the natural parameter θ,
224+ with sufficient statistic [θ, A(θ)] and base measure 1. The resulting
225+ family has its own [log_partition, sufficient_statistics, ...] such that
226+ the posterior updates are given by adding the observed sufficient statistics.
227+
228+ Returns:
229+ ContinuousExponentialClassFamily: The conjugate prior family.
230+ """
231+
139232 def conjugate_sufficient (
140233 theta : NumberParameter ,
141234 ) -> NumberParameter :
@@ -193,6 +286,21 @@ def transform(
193286 self ,
194287 transform_function : Callable [[NumberParameter ], NumberParameter ],
195288 ) -> ContinuousExponentialClassFamily :
289+ """
290+ Transform the random variable by a monotonic, differentiable function.
291+
292+ The new density is obtained via the change‑of‑variable formula.
293+ The sufficient statistic becomes T(transform⁻¹(x)) and the base measure
294+ gains the Jacobian factor.
295+
296+ Args:
297+ transform_function: Invertible, differentiable function g(x) such that
298+ y = g(x). Must be defined on the original support.
299+
300+ Returns:
301+ ContinuousExponentialClassFamily: A new family for the transformed variable.
302+ """
303+
196304 def calculate_jacobian (x : NumberParameter ) -> NumberParameter :
197305 if not isinstance (x , Iterable ):
198306 x = np .array ([x ])
@@ -223,38 +331,40 @@ def new_normalization(x: NumberParameter) -> NumberParameter:
223331
224332 @property
225333 def _mean (self ) -> ParametrizedFunction :
334+ """Compute the mean E[X] by numerical integration over the density."""
335+
226336 def mean_func (parametrization : Parametrization , x : Any ) -> Any :
227337 parametrization = cast (ExponentialFamilyParametrization , parametrization )
228338 dimension_size = 1
229339 if hasattr (x , "__len__" ):
230340 dimension_size = len (x )
231341 return nquad (
232- lambda x : ( # type: ignore[arg-type]
233- np .dot (x , self .density (parametrization , x )) if x in self ._support else 0
234- ),
342+ lambda x : np .dot (x , self .density (parametrization , x )) if x in self ._support else 0 ,
235343 [(float ("-inf" ), float ("inf" ))] * dimension_size ,
236344 )[0 ]
237345
238346 return mean_func
239347
240348 @property
241349 def _second_moment (self ) -> ParametrizedFunction :
350+ """Compute the second moment E[X²] by numerical integration."""
351+
242352 def func (parametrization : Parametrization , x : Any ) -> Any :
243353 parametrization = cast (ExponentialFamilyParametrization , parametrization )
244354 dimension_size = 1
245355 if hasattr (x , "__len__" ):
246356 dimension_size = len (x )
247357 return nquad (
248- lambda x : ( # type: ignore[arg-type]
249- x ** 2 * self .density (parametrization , x ) if x in self ._support else 0
250- ),
358+ lambda x : x ** 2 * self .density (parametrization , x ) if x in self ._support else 0 ,
251359 [(float ("-inf" ), float ("inf" ))] * dimension_size ,
252360 )[0 ]
253361
254362 return func
255363
256364 @property
257365 def _var (self ) -> ParametrizedFunction :
366+ """Compute the variance Var[X] = E[X²] - (E[X])²."""
367+
258368 def func (parametrization : Parametrization , x : Any ) -> Any :
259369 parametrization = cast (ExponentialFamilyParametrization , parametrization )
260370 return self ._second_moment (parametrization , x ) - self ._mean (parametrization , x ) ** 2
@@ -264,6 +374,22 @@ def func(parametrization: Parametrization, x: Any) -> Any:
264374 def posterior_hyperparameters (
265375 self , parametrizaiton : ExponentialConjugateHyperparameters , sample : list [Any ]
266376 ) -> ExponentialConjugateHyperparameters :
377+ """
378+ Update the conjugate prior hyperparameters given observed data.
379+
380+ For a conjugate prior with hyperparameters (ν₀, n₀), the posterior
381+ hyperparameters become:
382+ ν = ν₀ + Σ_{i} T(x_i)
383+ n = n₀ + N
384+
385+ Args:
386+ parametrizaiton: Current conjugate hyperparameters.
387+ sample: List of observations (each can be scalar or array).
388+
389+ Returns:
390+ ExponentialConjugateHyperparameters:
391+ Updated hyperparameters after incorporating the sample.
392+ """
267393 posterior_effective_suff_stat_value = parametrizaiton .effective_suff_stat_value
268394 posterior_effective_sample_size = parametrizaiton .effective_sample_size
269395 if hasattr (sample , "__iter__" ) and not isinstance (sample , str ):
@@ -283,6 +409,19 @@ def posterior_hyperparameters(
283409
284410 @property
285411 def posterior_predictive (self ) -> ParametricFamily :
412+ """
413+ Construct the posterior predictive distribution.
414+
415+ For a conjugate prior, the posterior predictive density of a new observation x
416+ given hyperparameters (ν, n) is:
417+ p(x | ν, n) = h(x) * exp( A(ν) - A(ν + T(x)) )
418+ where A(·) is the log‑partition function of the conjugate prior family.
419+
420+ Returns:
421+ ParametricFamily: A family with parametrization `ExponentialConjugateHyperparameters`
422+ and a `pdf` method implementing the posterior predictive density.
423+ """
424+
286425 def conjugate_log_partition (
287426 parametrization : ExponentialConjugateHyperparameters ,
288427 ) -> NumberParameter :
0 commit comments