1+ """
2+ Exponential family distributions in continuous spaces.
3+
4+ This module implements the continuous exponential family of probability distributions,
5+ their conjugate priors, posterior inference, and posterior predictive distributions.
6+ """
7+
18from __future__ import annotations
29
310__author__ = "Vinogradov Ilya"
411__copyright__ = "Copyright (c) 2025 PySATL project"
512__license__ = "SPDX-License-Identifier: MIT"
613
7-
814from collections .abc import Callable , Iterable
915from dataclasses import dataclass
1016from typing import TYPE_CHECKING , Any , cast
@dataclass
class ExponentialFamilyParametrization(Parametrization):
    """
    Canonical (natural) parametrization of an exponential family distribution.

    The distribution is described by its natural parameter vector `theta`.
    In terms of this parameter the density is written as:
        f(x|θ) = h(x) * exp(θᵀ T(x) - A(θ))

    Attributes:
        theta (NumberParameter): Natural parameter; either a scalar or an array.
    """

    theta: NumberParameter

    def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization:
        """Return this object unchanged, since it already is the base parametrization."""
        return self
4963
5064
@dataclass
class ExponentialConjugateHyperparameters(Parametrization):
    """
    Hyperparameters (ν₀, n₀) of the conjugate prior of an exponential family.

    The prior over the natural parameter θ is taken to be of the form:
        p(θ) ∝ exp(ν₀ᵀ θ + n₀ A(θ))
    with:
        effective_suff_stat_value = ν₀
        effective_sample_size = n₀

    NOTE(review): the sign convention of the A(θ) term is not visible from this
    class alone; confirm it against the conjugate prior family construction.

    Attributes:
        effective_suff_stat_value (NumericArray): Pseudo-sufficient statistic ν₀.
        effective_sample_size (Number): Pseudo-sample size n₀ (expected to be a
            non-negative scalar; not validated here).
    """

    effective_suff_stat_value: NumericArray
    effective_sample_size: Number

    def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization:
        """
        Pack the hyperparameters into a single canonical parameter vector.

        Returns:
            ExponentialFamilyParametrization: Parametrization whose parameter is
            the array [ν₀, n₀], i.e. n₀ appended after the entries of ν₀.
        """
        # np.append without an axis flattens both operands before joining them.
        packed = np.append(self.effective_suff_stat_value, self.effective_sample_size)
        return ExponentialFamilyParametrization(packed)
6093
6194
6295class ContinuousExponentialClassFamily (ParametricFamily ):
6396 """
64- Representation of exponential class with density = h(x) * exp(<n(t), T(x)> + A(t)),
65- where canonical parametrization is that, when n = t
66-
67- Usage of this class:
68- - you can use method transform_to_another to replace x to smth else, for example, into
97+ Representation of a continuous exponential family distribution.
98+
99+ The density is given by:
100+ f(x|θ) = h(x) * exp(θᵀ T(x) - A(θ))
101+
102+ where:
103+ - θ is the natural parameter,
104+ - T(x) is the sufficient statistic vector,
105+ - h(x) is the base measure (the `normalization_constant`),
106+ - A(θ) is the log‑partition function.
107+
108+ This class supports:
109+ - Canonical parametrization (θ) via `ExponentialFamilyParametrization`.
110+ - Conjugate prior families.
111+ - Posterior updates and posterior predictive distributions.
112+ - Transformation of the random variable (change of variable with Jacobian).
113+
114+ The user must supply functions for the log‑partition `log_partition`,
115+ sufficient statistics `sufficient_statistics`,
116+ base measure `normalization_constant`, as well as the support of the distribution,
117+ the natural parameter space and the range of the sufficient statistic.
69118 """
70119
71120 def __init__ (
@@ -83,6 +132,22 @@ def __init__(
83132 support_by_parametrization : SupportArg = None ,
84133 base_score : Callable [[Parametrization , NumericArray ], NumericArray ] | None = None ,
85134 ):
135+ """
136+ Initialize a continuous exponential family distribution.
137+
138+ Args:
139+ log_partition: Function A(θ) – the log‑partition function.
140+ sufficient_statistics: Function T(x) – the sufficient statistic vector.
141+ normalization_constant: Function h(x) – the base measure.
142+ support: Predicate defining the support of the distribution.
143+ parameter_space: Predicate defining the natural parameter space.
144+ sufficient_statistics_values: Predicate defining the range of T(x).
145+ name: Name of the family.
146+ distr_type: Type of distribution or a callable returning it.
147+ distr_parametrizations: List of parametrization names this family supports.
148+ support_by_parametrization: Callable that returns the support given a parametrization.
149+ base_score: Optional base score function.
150+ """
86151 self ._sufficient = sufficient_statistics
87152 self ._log_partition = log_partition
88153 self ._normalization = normalization_constant
@@ -113,6 +178,16 @@ def __init__(
113178
114179 @property
115180 def log_density (self ) -> ParametrizedFunction :
181+ """
182+ Log‑density function for the exponential family.
183+
184+ The function takes a parametrization (must be `ExponentialFamilyParametrization`)
185+ and a point `x`, and returns log f(x|θ). Returns -inf for x outside the support.
186+
187+ Returns:
188+ Callable[[Parametrization, NumberParameter], Number]
189+ """
190+
116191 def log_density_func (parametrization : Parametrization , x : NumberParameter ) -> Number :
117192 parametrization = cast (ExponentialFamilyParametrization , parametrization )
118193 parametrization = parametrization .transform_to_base_parametrization ()
@@ -132,10 +207,28 @@ def log_density_func(parametrization: Parametrization, x: NumberParameter) -> Nu
132207
133208 @property
134209 def density (self ) -> ParametrizedFunction :
210+ """
211+ Density function (exponentiated log‑density).
212+
213+ Returns:
214+ Callable[[Parametrization, NumberParameter], Number]
215+ """
135216 return lambda parametrization , x : np .exp (self .log_density (parametrization , x ))
136217
137218 @property
138219 def conjugate_prior_family (self ) -> ContinuousExponentialClassFamily :
220+ """
221+ Build the conjugate prior family for this exponential family.
222+
223+ The conjugate prior is an exponential family in the natural parameter θ,
224+ with sufficient statistic [θ, A(θ)] and base measure 1. The resulting
225+ family has its own [log_partition, sufficient_statistics, ...] such that
226+ the posterior updates are given by adding the observed sufficient statistics.
227+
228+ Returns:
229+ ContinuousExponentialClassFamily: The conjugate prior family.
230+ """
231+
139232 def conjugate_sufficient (
140233 theta : NumberParameter ,
141234 ) -> NumberParameter :
@@ -193,6 +286,21 @@ def transform(
193286 self ,
194287 transform_function : Callable [[NumberParameter ], NumberParameter ],
195288 ) -> ContinuousExponentialClassFamily :
289+ """
290+ Transform the random variable by a monotonic, differentiable function.
291+
292+ The new density is obtained via the change‑of‑variable formula.
293+ The sufficient statistic becomes T(transform⁻¹(x)) and the base measure
294+ gains the Jacobian factor.
295+
296+ Args:
297+ transform_function: Invertible, differentiable function g(x) such that
298+ y = g(x). Must be defined on the original support.
299+
300+ Returns:
301+ ContinuousExponentialClassFamily: A new family for the transformed variable.
302+ """
303+
196304 def calculate_jacobian (x : NumberParameter ) -> NumberParameter :
197305 if not isinstance (x , Iterable ):
198306 x = np .array ([x ])
@@ -223,6 +331,8 @@ def new_normalization(x: NumberParameter) -> NumberParameter:
223331
224332 @property
225333 def _mean (self ) -> ParametrizedFunction :
334+ """Compute the mean E[X] by numerical integration over the density."""
335+
226336 def mean_func (parametrization : Parametrization , x : Any ) -> Any :
227337 parametrization = cast (ExponentialFamilyParametrization , parametrization )
228338 dimension_size = 1
@@ -239,6 +349,8 @@ def mean_func(parametrization: Parametrization, x: Any) -> Any:
239349
240350 @property
241351 def _second_moment (self ) -> ParametrizedFunction :
352+ """Compute the second moment E[X²] by numerical integration."""
353+
242354 def func (parametrization : Parametrization , x : Any ) -> Any :
243355 parametrization = cast (ExponentialFamilyParametrization , parametrization )
244356 dimension_size = 1
@@ -255,6 +367,8 @@ def func(parametrization: Parametrization, x: Any) -> Any:
255367
256368 @property
257369 def _var (self ) -> ParametrizedFunction :
370+ """Compute the variance Var[X] = E[X²] - (E[X])²."""
371+
258372 def func (parametrization : Parametrization , x : Any ) -> Any :
259373 parametrization = cast (ExponentialFamilyParametrization , parametrization )
260374 return self ._second_moment (parametrization , x ) - self ._mean (parametrization , x ) ** 2
@@ -264,6 +378,22 @@ def func(parametrization: Parametrization, x: Any) -> Any:
264378 def posterior_hyperparameters (
265379 self , parametrizaiton : ExponentialConjugateHyperparameters , sample : list [Any ]
266380 ) -> ExponentialConjugateHyperparameters :
381+ """
382+ Update the conjugate prior hyperparameters given observed data.
383+
384+ For a conjugate prior with hyperparameters (ν₀, n₀), the posterior
385+ hyperparameters become:
386+ ν = ν₀ + Σ_{i} T(x_i)
387+ n = n₀ + N
388+
389+ Args:
390+ parametrizaiton: Current conjugate hyperparameters.
391+ sample: List of observations (each can be scalar or array).
392+
393+ Returns:
394+ ExponentialConjugateHyperparameters:
395+ Updated hyperparameters after incorporating the sample.
396+ """
267397 posterior_effective_suff_stat_value = parametrizaiton .effective_suff_stat_value
268398 posterior_effective_sample_size = parametrizaiton .effective_sample_size
269399 if hasattr (sample , "__iter__" ) and not isinstance (sample , str ):
@@ -283,6 +413,19 @@ def posterior_hyperparameters(
283413
284414 @property
285415 def posterior_predictive (self ) -> ParametricFamily :
416+ """
417+ Construct the posterior predictive distribution.
418+
419+ For a conjugate prior, the posterior predictive density of a new observation x
420+ given hyperparameters (ν, n) is:
421+ p(x | ν, n) = h(x) * exp( A(ν) - A(ν + T(x)) )
422+ where A(·) is the log‑partition function of the conjugate prior family.
423+
424+ Returns:
425+ ParametricFamily: A family with parametrization `ExponentialConjugateHyperparameters`
426+ and a `pdf` method implementing the posterior predictive density.
427+ """
428+
286429 def conjugate_log_partition (
287430 parametrization : ExponentialConjugateHyperparameters ,
288431 ) -> NumberParameter :
0 commit comments