Skip to content

Commit 1d58077

Browse files
committed
docs(exponential): add docstrings
1 parent 73a4ddb commit 1d58077

3 files changed

Lines changed: 156 additions & 14 deletions

File tree

src/pysatl_core/families/exponential_family.py

Lines changed: 152 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
1+
"""
2+
Exponential family distributions in continuous spaces.
3+
4+
This module implements the continuous exponential family of probability distributions,
5+
their conjugate priors, posterior inference, and posterior predictive distributions.
6+
"""
7+
18
from __future__ import annotations
29

310
__author__ = "Vinogradov Ilya"
411
__copyright__ = "Copyright (c) 2025 PySATL project"
512
__license__ = "SPDX-License-Identifier: MIT"
613

7-
814
from collections.abc import Callable, Iterable
915
from dataclasses import dataclass
1016
from typing import TYPE_CHECKING, Any, cast
@@ -39,33 +45,76 @@
3945
@dataclass
4046
class ExponentialFamilyParametrization(Parametrization):
4147
"""
42-
Standard parametrization of Exponential Family.
48+
Standard parametrization of an exponential family distribution.
49+
50+
This parametrization uses the natural (canonical) parameter vector `theta`
51+
The density is expressed as:
52+
f(x|θ) = h(x) * exp(θᵀ T(x) - A(θ))
53+
54+
Attributes:
55+
theta (NumberParameter): Natural parameter vector (can be a scalar or array)
4356
"""
4457

4558
theta: NumberParameter
4659

4760
def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization:
61+
"""Return the base parametrization (identity transform for canonical form)."""
4862
return self
4963

5064

5165
@dataclass
5266
class ExponentialConjugateHyperparameters(Parametrization):
67+
"""
68+
Hyperparameters for the conjugate prior of an exponential family
69+
70+
For a prior of the form:
71+
p(θ) ∝ exp(ν₀ᵀ T(θ) + n₀ A(θ))
72+
the hyperparameters are:
73+
effective_suff_stat_value = ν₀
74+
effective_sample_size = n₀
75+
76+
Attributes:
77+
effective_suff_stat_value (NumericArray): Pseudo‑sufficient statistic ν₀
78+
effective_sample_size (Number): Pseudo‑sample size n₀ (a non‑negative scalar)
79+
"""
80+
5381
effective_suff_stat_value: NumericArray
5482
effective_sample_size: Number
5583

5684
def transform_to_base_parametrization(self) -> ExponentialFamilyParametrization:
85+
"""
86+
Convert hyperparameters to a canonical parametrization.
87+
88+
The resulting parameter vector is [ν₀, n₀] concatenated.
89+
"""
5790
return ExponentialFamilyParametrization(
5891
np.append(self.effective_suff_stat_value, self.effective_sample_size)
5992
)
6093

6194

6295
class ContinuousExponentialClassFamily(ParametricFamily):
6396
"""
64-
Representation of exponential class with density = h(x) * exp(<n(t), T(x)> + A(t)),
65-
where canonical parametrization is that, when n = t
66-
67-
Usage of this class:
68-
- you can use method transform_to_another to replace x to smth else, for example, into
97+
Representation of a continuous exponential family distribution.
98+
99+
The density is given by:
100+
f(x|θ) = h(x) * exp(θᵀ T(x) - A(θ))
101+
102+
where:
103+
- θ is the natural parameter,
104+
- T(x) is the sufficient statistic vector,
105+
- h(x) is the base measure (the `normalization_constant`),
106+
- A(θ) is the log‑partition function.
107+
108+
This class supports:
109+
- Canonical parametrization (θ) via `ExponentialFamilyParametrization`.
110+
- Conjugate prior families.
111+
- Posterior updates and posterior predictive distributions.
112+
- Transformation of the random variable (change of variable with Jacobian).
113+
114+
The user must supply functions for the log‑partition `log_partition`,
115+
sufficient statistics `sufficient_statistics`,
116+
base measure `normalization_constant`, as well as the support of the distribution,
117+
the natural parameter space and the range of the sufficient statistic.
69118
"""
70119

71120
def __init__(
@@ -83,6 +132,22 @@ def __init__(
83132
support_by_parametrization: SupportArg = None,
84133
base_score: Callable[[Parametrization, NumericArray], NumericArray] | None = None,
85134
):
135+
"""
136+
Initialize a continuous exponential family distribution.
137+
138+
Args:
139+
log_partition: Function A(θ) – the log‑partition function.
140+
sufficient_statistics: Function T(x) – the sufficient statistic vector.
141+
normalization_constant: Function h(x) – the base measure.
142+
support: Predicate defining the support of the distribution.
143+
parameter_space: Predicate defining the natural parameter space.
144+
sufficient_statistics_values: Predicate defining the range of T(x).
145+
name: Name of the family.
146+
distr_type: Type of distribution or a callable returning it.
147+
distr_parametrizations: List of parametrization names this family supports.
148+
support_by_parametrization: Callable that returns the support given a parametrization.
149+
base_score: Optional base score function.
150+
"""
86151
self._sufficient = sufficient_statistics
87152
self._log_partition = log_partition
88153
self._normalization = normalization_constant
@@ -113,6 +178,16 @@ def __init__(
113178

114179
@property
115180
def log_density(self) -> ParametrizedFunction:
181+
"""
182+
Log‑density function for the exponential family.
183+
184+
The function takes a parametrization (must be `ExponentialFamilyParametrization`)
185+
and a point `x`, and returns log f(x|θ). Returns -inf for x outside the support.
186+
187+
Returns:
188+
Callable[[Parametrization, NumberParameter], Number]
189+
"""
190+
116191
def log_density_func(parametrization: Parametrization, x: NumberParameter) -> Number:
117192
parametrization = cast(ExponentialFamilyParametrization, parametrization)
118193
parametrization = parametrization.transform_to_base_parametrization()
@@ -132,10 +207,28 @@ def log_density_func(parametrization: Parametrization, x: NumberParameter) -> Nu
132207

133208
@property
134209
def density(self) -> ParametrizedFunction:
210+
"""
211+
Density function (exponentiated log‑density).
212+
213+
Returns:
214+
Callable[[Parametrization, NumberParameter], Number]
215+
"""
135216
return lambda parametrization, x: np.exp(self.log_density(parametrization, x))
136217

137218
@property
138219
def conjugate_prior_family(self) -> ContinuousExponentialClassFamily:
220+
"""
221+
Build the conjugate prior family for this exponential family.
222+
223+
The conjugate prior is an exponential family in the natural parameter θ,
224+
with sufficient statistic [θ, A(θ)] and base measure 1. The resulting
225+
family has its own [log_partition, sufficient_statistics, ...] such that
226+
the posterior updates are given by adding the observed sufficient statistics.
227+
228+
Returns:
229+
ContinuousExponentialClassFamily: The conjugate prior family.
230+
"""
231+
139232
def conjugate_sufficient(
140233
theta: NumberParameter,
141234
) -> NumberParameter:
@@ -193,6 +286,21 @@ def transform(
193286
self,
194287
transform_function: Callable[[NumberParameter], NumberParameter],
195288
) -> ContinuousExponentialClassFamily:
289+
"""
290+
Transform the random variable by a monotonic, differentiable function.
291+
292+
The new density is obtained via the change‑of‑variable formula.
293+
The sufficient statistic becomes T(transform⁻¹(x)) and the base measure
294+
gains the Jacobian factor.
295+
296+
Args:
297+
transform_function: Invertible, differentiable function g(x) such that
298+
y = g(x). Must be defined on the original support.
299+
300+
Returns:
301+
ContinuousExponentialClassFamily: A new family for the transformed variable.
302+
"""
303+
196304
def calculate_jacobian(x: NumberParameter) -> NumberParameter:
197305
if not isinstance(x, Iterable):
198306
x = np.array([x])
@@ -223,38 +331,40 @@ def new_normalization(x: NumberParameter) -> NumberParameter:
223331

224332
@property
225333
def _mean(self) -> ParametrizedFunction:
334+
"""Compute the mean E[X] by numerical integration over the density."""
335+
226336
def mean_func(parametrization: Parametrization, x: Any) -> Any:
227337
parametrization = cast(ExponentialFamilyParametrization, parametrization)
228338
dimension_size = 1
229339
if hasattr(x, "__len__"):
230340
dimension_size = len(x)
231341
return nquad(
232-
lambda x: ( # type: ignore[arg-type]
233-
np.dot(x, self.density(parametrization, x)) if x in self._support else 0
234-
),
342+
lambda x: np.dot(x, self.density(parametrization, x)) if x in self._support else 0,
235343
[(float("-inf"), float("inf"))] * dimension_size,
236344
)[0]
237345

238346
return mean_func
239347

240348
@property
241349
def _second_moment(self) -> ParametrizedFunction:
350+
"""Compute the second moment E[X²] by numerical integration."""
351+
242352
def func(parametrization: Parametrization, x: Any) -> Any:
243353
parametrization = cast(ExponentialFamilyParametrization, parametrization)
244354
dimension_size = 1
245355
if hasattr(x, "__len__"):
246356
dimension_size = len(x)
247357
return nquad(
248-
lambda x: ( # type: ignore[arg-type]
249-
x**2 * self.density(parametrization, x) if x in self._support else 0
250-
),
358+
lambda x: x**2 * self.density(parametrization, x) if x in self._support else 0,
251359
[(float("-inf"), float("inf"))] * dimension_size,
252360
)[0]
253361

254362
return func
255363

256364
@property
257365
def _var(self) -> ParametrizedFunction:
366+
"""Compute the variance Var[X] = E[X²] - (E[X])²."""
367+
258368
def func(parametrization: Parametrization, x: Any) -> Any:
259369
parametrization = cast(ExponentialFamilyParametrization, parametrization)
260370
return self._second_moment(parametrization, x) - self._mean(parametrization, x) ** 2
@@ -264,6 +374,22 @@ def func(parametrization: Parametrization, x: Any) -> Any:
264374
def posterior_hyperparameters(
265375
self, parametrizaiton: ExponentialConjugateHyperparameters, sample: list[Any]
266376
) -> ExponentialConjugateHyperparameters:
377+
"""
378+
Update the conjugate prior hyperparameters given observed data.
379+
380+
For a conjugate prior with hyperparameters (ν₀, n₀), the posterior
381+
hyperparameters become:
382+
ν = ν₀ + Σ_{i} T(x_i)
383+
n = n₀ + N
384+
385+
Args:
386+
parametrizaiton: Current conjugate hyperparameters.
387+
sample: List of observations (each can be scalar or array).
388+
389+
Returns:
390+
ExponentialConjugateHyperparameters:
391+
Updated hyperparameters after incorporating the sample.
392+
"""
267393
posterior_effective_suff_stat_value = parametrizaiton.effective_suff_stat_value
268394
posterior_effective_sample_size = parametrizaiton.effective_sample_size
269395
if hasattr(sample, "__iter__") and not isinstance(sample, str):
@@ -283,6 +409,19 @@ def posterior_hyperparameters(
283409

284410
@property
285411
def posterior_predictive(self) -> ParametricFamily:
412+
"""
413+
Construct the posterior predictive distribution.
414+
415+
For a conjugate prior, the posterior predictive density of a new observation x
416+
given hyperparameters (ν, n) is:
417+
p(x | ν, n) = h(x) * exp( A(ν) - A(ν + T(x)) )
418+
where A(·) is the log‑partition function of the conjugate prior family.
419+
420+
Returns:
421+
ParametricFamily: A family with parametrization `ExponentialConjugateHyperparameters`
422+
and a `pdf` method implementing the posterior predictive density.
423+
"""
424+
286425
def conjugate_log_partition(
287426
parametrization: ExponentialConjugateHyperparameters,
288427
) -> NumberParameter:

src/pysatl_core/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,9 @@ def contains(self, x: Number | NumericArray) -> bool | BoolArray:
268268
if not hasattr(x, "__iter__"):
269269
x = np.array([x])
270270

271+
x = np.array(x)
272+
assert len(x) == len(self.intervals)
273+
271274
return all(
272275
x_coordinate in interval
273276
for interval, x_coordinate in zip(self.intervals, x, strict=True)

tests/unit/families/test_exponential_family.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__author__ = "Leonid Elkin"
1+
__author__ = "Vinogradov Ilya"
22
__copyright__ = "Copyright (c) 2025 PySATL project"
33
__license__ = "SPDX-License-Identifier: MIT"
44

0 commit comments

Comments
 (0)