Skip to content

Commit f67f3fd

Browse files
author
domosedy
committed
docs(exponential): add docstrings
1 parent 41d0a3b commit f67f3fd

1 file changed

Lines changed: 150 additions & 7 deletions

File tree

src/pysatl_core/families/exponential_family.py

Lines changed: 150 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
1+
"""
2+
Exponential family distributions in continuous spaces.
3+
4+
This module implements the continuous exponential family of probability distributions,
5+
their conjugate priors, posterior inference, and posterior predictive distributions.
6+
"""
7+
18
from __future__ import annotations
29

310
__author__ = "Vinogradov Ilya"
411
__copyright__ = "Copyright (c) 2025 PySATL project"
512
__license__ = "SPDX-License-Identifier: MIT"
613

7-
814
from collections.abc import Callable, Iterable
915
from dataclasses import dataclass
1016
from typing import TYPE_CHECKING, Any, cast
@@ -39,33 +45,76 @@
3945
@dataclass
4046
class ExponentialFamilyParametrization(Parametrization):
4147
"""
42-
Standard parametrization of Exponential Family.
48+
Standard parametrization of an exponential family distribution.
49+
50+
This parametrization uses the natural (canonical) parameter vector `theta`.
51+
The density is expressed as:
52+
f(x|θ) = h(x) * exp(θᵀ T(x) - A(θ))
53+
54+
Attributes:
55+
theta (NumberParameter): Natural parameter vector (can be a scalar or array)
4356
"""
4457

4558
theta: NumberParameter
4659

4760
def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization:
61+
"""Return the base parametrization (identity transform for canonical form)."""
4862
return self
4963

5064

5165
@dataclass
5266
class ExponentialConjugateHyperparameters(Parametrization):
67+
"""
68+
Hyperparameters for the conjugate prior of an exponential family.
69+
70+
For a prior of the form:
71+
p(θ) ∝ exp(ν₀ᵀ θ + n₀ A(θ))
72+
the hyperparameters are:
73+
effective_suff_stat_value = ν₀
74+
effective_sample_size = n₀
75+
76+
Attributes:
77+
effective_suff_stat_value (NumericArray): Pseudo‑sufficient statistic ν₀
78+
effective_sample_size (Number): Pseudo‑sample size n₀ (a non‑negative scalar)
79+
"""
80+
5381
effective_suff_stat_value: NumericArray
5482
effective_sample_size: Number
5583

5684
def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization:
85+
"""
86+
Convert hyperparameters to a canonical parametrization.
87+
88+
The resulting parameter vector is ν₀ with n₀ appended as its last component: [ν₀, n₀].
89+
"""
5790
return ExponentialFamilyParametrization(
5891
np.append(self.effective_suff_stat_value, self.effective_sample_size)
5992
)
6093

6194

6295
class ContinuousExponentialClassFamily(ParametricFamily):
6396
"""
64-
Representation of exponential class with density = h(x) * exp(<n(t), T(x)> + A(t)),
65-
where canonical parametrization is that, when n = t
66-
67-
Usage of this class:
68-
- you can use method transform_to_another to replace x to smth else, for example, into
97+
Representation of a continuous exponential family distribution.
98+
99+
The density is given by:
100+
f(x|θ) = h(x) * exp(θᵀ T(x) - A(θ))
101+
102+
where:
103+
- θ is the natural parameter,
104+
- T(x) is the sufficient statistic vector,
105+
- h(x) is the base measure (the `normalization_constant`),
106+
- A(θ) is the log‑partition function.
107+
108+
This class supports:
109+
- Canonical parametrization (θ) via `ExponentialFamilyParametrization`.
110+
- Conjugate prior families.
111+
- Posterior updates and posterior predictive distributions.
112+
- Transformation of the random variable (change of variable with Jacobian).
113+
114+
The user must supply functions for the log‑partition `log_partition`,
115+
sufficient statistics `sufficient_statistics`,
116+
base measure `normalization_constant`, as well as the support of the distribution,
117+
the natural parameter space and the range of the sufficient statistic.
69118
"""
70119

71120
def __init__(
@@ -83,6 +132,22 @@ def __init__(
83132
support_by_parametrization: SupportArg = None,
84133
base_score: Callable[[Parametrization, NumericArray], NumericArray] | None = None,
85134
):
135+
"""
136+
Initialize a continuous exponential family distribution.
137+
138+
Args:
139+
log_partition: Function A(θ) – the log‑partition function.
140+
sufficient_statistics: Function T(x) – the sufficient statistic vector.
141+
normalization_constant: Function h(x) – the base measure.
142+
support: Predicate defining the support of the distribution.
143+
parameter_space: Predicate defining the natural parameter space.
144+
sufficient_statistics_values: Predicate defining the range of T(x).
145+
name: Name of the family.
146+
distr_type: Type of distribution or a callable returning it.
147+
distr_parametrizations: List of parametrization names this family supports.
148+
support_by_parametrization: Callable that returns the support given a parametrization.
149+
base_score: Optional base score function.
150+
"""
86151
self._sufficient = sufficient_statistics
87152
self._log_partition = log_partition
88153
self._normalization = normalization_constant
@@ -113,6 +178,16 @@ def __init__(
113178

114179
@property
115180
def log_density(self) -> ParametrizedFunction:
181+
"""
182+
Log‑density function for the exponential family.
183+
184+
The function takes a parametrization (must be `ExponentialFamilyParametrization`)
185+
and a point `x`, and returns log f(x|θ). Returns -inf for x outside the support.
186+
187+
Returns:
188+
Callable[[Parametrization, NumberParameter], Number]
189+
"""
190+
116191
def log_density_func(parametrization: Parametrization, x: NumberParameter) -> Number:
117192
parametrization = cast(ExponentialFamilyParametrization, parametrization)
118193
parametrization = parametrization.transform_to_base_parametrization()
@@ -132,10 +207,28 @@ def log_density_func(parametrization: Parametrization, x: NumberParameter) -> Nu
132207

133208
@property
134209
def density(self) -> ParametrizedFunction:
210+
"""
211+
Density function (exponentiated log‑density).
212+
213+
Returns:
214+
Callable[[Parametrization, NumberParameter], Number]
215+
"""
135216
return lambda parametrization, x: np.exp(self.log_density(parametrization, x))
136217

137218
@property
138219
def conjugate_prior_family(self) -> ContinuousExponentialClassFamily:
220+
"""
221+
Build the conjugate prior family for this exponential family.
222+
223+
The conjugate prior is an exponential family in the natural parameter θ,
224+
with sufficient statistic [θ, A(θ)] and base measure 1. The resulting
225+
family defines its own `log_partition`, `sufficient_statistics`, etc., so that
226+
posterior updates reduce to adding the observed sufficient statistics.
227+
228+
Returns:
229+
ContinuousExponentialClassFamily: The conjugate prior family.
230+
"""
231+
139232
def conjugate_sufficient(
140233
theta: NumberParameter,
141234
) -> NumberParameter:
@@ -193,6 +286,21 @@ def transform(
193286
self,
194287
transform_function: Callable[[NumberParameter], NumberParameter],
195288
) -> ContinuousExponentialClassFamily:
289+
"""
290+
Transform the random variable by a monotonic, differentiable function.
291+
292+
The new density is obtained via the change‑of‑variable formula.
293+
The sufficient statistic becomes T(transform⁻¹(x)) and the base measure
294+
gains the Jacobian factor.
295+
296+
Args:
297+
transform_function: Invertible, differentiable function g(x) such that
298+
y = g(x). Must be defined on the original support.
299+
300+
Returns:
301+
ContinuousExponentialClassFamily: A new family for the transformed variable.
302+
"""
303+
196304
def calculate_jacobian(x: NumberParameter) -> NumberParameter:
197305
if not isinstance(x, Iterable):
198306
x = np.array([x])
@@ -223,6 +331,8 @@ def new_normalization(x: NumberParameter) -> NumberParameter:
223331

224332
@property
225333
def _mean(self) -> ParametrizedFunction:
334+
"""Compute the mean E[X] by numerical integration over the density."""
335+
226336
def mean_func(parametrization: Parametrization, x: Any) -> Any:
227337
parametrization = cast(ExponentialFamilyParametrization, parametrization)
228338
dimension_size = 1
@@ -239,6 +349,8 @@ def mean_func(parametrization: Parametrization, x: Any) -> Any:
239349

240350
@property
241351
def _second_moment(self) -> ParametrizedFunction:
352+
"""Compute the second moment E[X²] by numerical integration."""
353+
242354
def func(parametrization: Parametrization, x: Any) -> Any:
243355
parametrization = cast(ExponentialFamilyParametrization, parametrization)
244356
dimension_size = 1
@@ -255,6 +367,8 @@ def func(parametrization: Parametrization, x: Any) -> Any:
255367

256368
@property
257369
def _var(self) -> ParametrizedFunction:
370+
"""Compute the variance Var[X] = E[X²] - (E[X])²."""
371+
258372
def func(parametrization: Parametrization, x: Any) -> Any:
259373
parametrization = cast(ExponentialFamilyParametrization, parametrization)
260374
return self._second_moment(parametrization, x) - self._mean(parametrization, x) ** 2
@@ -264,6 +378,22 @@ def func(parametrization: Parametrization, x: Any) -> Any:
264378
def posterior_hyperparameters(
265379
self, parametrizaiton: ExponentialConjugateHyperparameters, sample: list[Any]
266380
) -> ExponentialConjugateHyperparameters:
381+
"""
382+
Update the conjugate prior hyperparameters given observed data.
383+
384+
For a conjugate prior with hyperparameters (ν₀, n₀), the posterior
385+
hyperparameters become:
386+
ν = ν₀ + Σ_{i} T(x_i)
387+
n = n₀ + N
388+
389+
Args:
390+
parametrizaiton: Current conjugate hyperparameters.
391+
sample: List of observations (each can be scalar or array).
392+
393+
Returns:
394+
ExponentialConjugateHyperparameters:
395+
Updated hyperparameters after incorporating the sample.
396+
"""
267397
posterior_effective_suff_stat_value = parametrizaiton.effective_suff_stat_value
268398
posterior_effective_sample_size = parametrizaiton.effective_sample_size
269399
if hasattr(sample, "__iter__") and not isinstance(sample, str):
@@ -283,6 +413,19 @@ def posterior_hyperparameters(
283413

284414
@property
285415
def posterior_predictive(self) -> ParametricFamily:
416+
"""
417+
Construct the posterior predictive distribution.
418+
419+
For a conjugate prior, the posterior predictive density of a new observation x
420+
given hyperparameters (ν, n) is:
421+
p(x | ν, n) = h(x) * exp( A(ν + T(x), n + 1) - A(ν, n) )
422+
where A(·) is the log‑partition function of the conjugate prior family.
423+
424+
Returns:
425+
ParametricFamily: A family with parametrization `ExponentialConjugateHyperparameters`
426+
and a `pdf` method implementing the posterior predictive density.
427+
"""
428+
286429
def conjugate_log_partition(
287430
parametrization: ExponentialConjugateHyperparameters,
288431
) -> NumberParameter:

0 commit comments

Comments
 (0)