Skip to content

Commit 6b68d73

Browse files
MetbcyQazalbash
andauthored
doc(gh-2187): docstrings for Gumbel, Laplace, Logistic (#2202)
* doc(gh-2187): enhance documentation for Gumbel, Laplace, and Logistic distributions Adds math-flavored docstrings to three more continuous distributions on the #2187 checklist, following the template from #2188, #2192, and #2199: - Gumbel: class, __init__, sample, log_prob, mean, variance, cdf, icdf. - Laplace: class, __init__, sample, log_prob, mean, variance, cdf, icdf, entropy. - Logistic: class, __init__, sample, log_prob, mean, variance, cdf, icdf, entropy. Docstrings only, no logic changes. ruff check and ruff format --check are clean. * fix: small fixes --------- Co-authored-by: Meesum Qazalbash <meesumqazalbash@gmail.com>
1 parent be26d95 commit 6b68d73

1 file changed

Lines changed: 230 additions & 1 deletion

File tree

numpyro/distributions/continuous.py

Lines changed: 230 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,7 @@ def variance(self) -> ArrayLike:
841841
r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then
842842
843843
.. math::
844-
\mathrm{Var}[X] = \frac{\alpha}{\lambda^2}
844+
\mathrm{Var}(X) = \frac{\alpha}{\lambda^2}
845845
"""
846846
return self.concentration / jnp.power(self.rate, 2)
847847

@@ -1306,6 +1306,22 @@ def mean(self) -> ArrayLike:
13061306

13071307

13081308
class Gumbel(Distribution):
1309+
r"""The Gumbel (maximum) distribution, a continuous real-valued
1310+
distribution parameterized by location :math:`\mu` and scale :math:`\beta > 0`.
1311+
It is the limiting distribution of the maximum of a large number of i.i.d.
1312+
samples from an exponential-tailed distribution.
1313+
1314+
The Probability Density Function (PDF) is:
1315+
1316+
.. math::
1317+
f(x \mid \mu, \beta) = \frac{1}{\beta} \exp\!\left(
1318+
-\frac{x - \mu}{\beta} - \exp\!\left(-\frac{x - \mu}{\beta}\right)
1319+
\right), \quad x \in \mathbb{R}
1320+
1321+
where :math:`\mu \in \mathbb{R}` is the location (:attr:`loc`) and
1322+
:math:`\beta > 0` is the scale (:attr:`scale`).
1323+
"""
1324+
13091325
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
13101326
support = constraints.real
13111327
reparametrized_params = ["loc", "scale"]
@@ -1317,6 +1333,11 @@ def __init__(
13171333
*,
13181334
validate_args: Optional[bool] = None,
13191335
) -> None:
1336+
r"""
1337+
:param loc: Location parameter :math:`\mu \in \mathbb{R}`. Defaults to ``0.0``.
1338+
:param scale: Scale parameter :math:`\beta > 0`. Defaults to ``1.0``.
1339+
:param validate_args: If True, enforce domain constraints during initialization.
1340+
"""
13201341
self.loc, self.scale = promote_shapes(loc, scale)
13211342
batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
13221343

@@ -1325,6 +1346,15 @@ def __init__(
13251346
)
13261347

13271348
def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
1349+
r"""Draw samples from the Gumbel distribution via the location-scale
1350+
transform :math:`X = \mu + \beta Z`, where
1351+
:math:`Z \sim \mathrm{Gumbel}(0, 1)` is drawn from
1352+
:func:`~jax.random.gumbel`.
1353+
1354+
:param key: A JAX PRNG key.
1355+
:param sample_shape: Sample dimensions to prepend to the batch shape.
1356+
:return: Real-valued samples from the Gumbel distribution.
1357+
"""
13281358
assert is_prng_key(key)
13291359
standard_gumbel_sample = random.gumbel(
13301360
key, shape=sample_shape + self.batch_shape + self.event_shape
@@ -1333,23 +1363,63 @@ def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLik
13331363

13341364
@validate_sample
13351365
def log_prob(self, value: ArrayLike) -> ArrayLike:
1366+
r"""Evaluate the log probability density function at ``value``.
1367+
1368+
Letting :math:`z = (x - \mu)/\beta`,
1369+
1370+
.. math::
1371+
\ln f(x \mid \mu, \beta) = -z - e^{-z} - \ln \beta
1372+
1373+
:param value: Real-valued point :math:`x` at which to evaluate the log PDF.
1374+
:return: Log probability density evaluated under the Gumbel distribution.
1375+
"""
13361376
z = (value - self.loc) / self.scale
13371377
return -(z + jnp.exp(-z)) - jnp.log(self.scale)
13381378

13391379
@property
13401380
def mean(self) -> ArrayLike:
1381+
r"""Mean of the Gumbel distribution:
1382+
1383+
.. math::
1384+
\mathbb{E}[X] = \mu + \beta \gamma
1385+
1386+
where :math:`\gamma \approx 0.5772\ldots` is the Euler-Mascheroni constant,
1387+
available at, :data:`~jax.numpy.euler_gamma`.
1388+
"""
13411389
return jnp.broadcast_to(
13421390
self.loc + self.scale * jnp.euler_gamma, self.batch_shape
13431391
)
13441392

13451393
@property
13461394
def variance(self) -> ArrayLike:
1395+
r"""Variance of the Gumbel distribution:
1396+
1397+
.. math::
1398+
\mathrm{Var}(X) = \frac{\pi^2}{6} \beta^2
1399+
"""
13471400
return jnp.broadcast_to(jnp.pi**2 / 6.0 * self.scale**2, self.batch_shape)
13481401

13491402
def cdf(self, value: ArrayLike) -> ArrayLike:
1403+
r"""Cumulative Distribution Function (CDF) of the Gumbel distribution:
1404+
1405+
.. math::
1406+
F(x \mid \mu, \beta) = \exp\!\left(-\exp\!\left(-\frac{x - \mu}{\beta}\right)\right)
1407+
1408+
:param value: Real-valued point :math:`x` at which to evaluate the CDF.
1409+
:return: CDF values in :math:`[0, 1]`.
1410+
"""
13501411
return jnp.exp(-jnp.exp((self.loc - value) / self.scale))
13511412

13521413
def icdf(self, q: ArrayLike) -> ArrayLike:
1414+
r"""Inverse CDF (quantile function) of the Gumbel distribution:
1415+
1416+
.. math::
1417+
F^{-1}(q \mid \mu, \beta) = \mu - \beta \ln(-\ln q),
1418+
\quad q \in (0, 1)
1419+
1420+
:param q: Quantile values in :math:`(0, 1)`.
1421+
:return: Real-valued quantiles of the Gumbel distribution at ``q``.
1422+
"""
13531423
return self.loc - self.scale * jnp.log(-jnp.log(q))
13541424

13551425

@@ -1415,6 +1485,22 @@ def variance(self) -> ArrayLike:
14151485

14161486

14171487
class Laplace(Distribution):
1488+
r"""The Laplace (double-exponential) distribution, a continuous real-valued
1489+
distribution parameterized by location :math:`\mu` and scale :math:`b > 0`.
1490+
It is the distribution of the difference of two i.i.d. exponential variates
1491+
and has heavier tails than the Normal distribution.
1492+
1493+
The Probability Density Function (PDF) is:
1494+
1495+
.. math::
1496+
f(x \mid \mu, b) = \frac{1}{2 b}
1497+
\exp\!\left(-\frac{|x - \mu|}{b}\right),
1498+
\quad x \in \mathbb{R}
1499+
1500+
where :math:`\mu \in \mathbb{R}` is the location (:attr:`loc`) and
1501+
:math:`b > 0` is the scale (:attr:`scale`).
1502+
"""
1503+
14181504
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
14191505
support = constraints.real
14201506
reparametrized_params = ["loc", "scale"]
@@ -1426,13 +1512,26 @@ def __init__(
14261512
*,
14271513
validate_args: Optional[bool] = None,
14281514
) -> None:
1515+
r"""
1516+
:param loc: Location parameter :math:`\mu \in \mathbb{R}`. Defaults to ``0.0``.
1517+
:param scale: Scale parameter :math:`b > 0`. Defaults to ``1.0``.
1518+
:param validate_args: If True, enforce domain constraints during initialization.
1519+
"""
14291520
self.loc, self.scale = promote_shapes(loc, scale)
14301521
batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
14311522
super(Laplace, self).__init__(
14321523
batch_shape=batch_shape, validate_args=validate_args
14331524
)
14341525

14351526
def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
1527+
r"""Draw samples via the location-scale transform
1528+
:math:`X = \mu + b Z`, where :math:`Z \sim \mathrm{Laplace}(0, 1)` is
1529+
drawn from :func:`~jax.random.laplace`.
1530+
1531+
:param key: A JAX PRNG key.
1532+
:param sample_shape: Sample dimensions to prepend to the batch shape.
1533+
:return: Real-valued samples from the Laplace distribution.
1534+
"""
14361535
assert is_prng_key(key)
14371536
eps = random.laplace(
14381537
key, shape=sample_shape + self.batch_shape + self.event_shape
@@ -1441,27 +1540,73 @@ def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLik
14411540

14421541
@validate_sample
14431542
def log_prob(self, value: ArrayLike) -> ArrayLike:
1543+
r"""Evaluate the log probability density function at ``value``:
1544+
1545+
.. math::
1546+
\ln f(x \mid \mu, b) = -\frac{|x - \mu|}{b} - \ln(2 b)
1547+
1548+
:param value: Real-valued point :math:`x` at which to evaluate the log PDF.
1549+
:return: Log probability density evaluated under the Laplace distribution.
1550+
"""
14441551
normalize_term = jnp.log(2 * self.scale)
14451552
value_scaled = jnp.abs(value - self.loc) / self.scale
14461553
return -value_scaled - normalize_term
14471554

14481555
@property
14491556
def mean(self) -> ArrayLike:
1557+
r"""Mean of the Laplace distribution:
1558+
1559+
.. math::
1560+
\mathbb{E}[X] = \mu
1561+
"""
14501562
return jnp.broadcast_to(self.loc, self.batch_shape)
14511563

14521564
@property
14531565
def variance(self) -> ArrayLike:
1566+
r"""Variance of the Laplace distribution:
1567+
1568+
.. math::
1569+
\mathrm{Var}(X) = 2 b^2
1570+
"""
14541571
return jnp.broadcast_to(2 * self.scale**2, self.batch_shape)
14551572

14561573
def cdf(self, value: ArrayLike) -> ArrayLike:
1574+
r"""Cumulative Distribution Function (CDF) of the Laplace distribution.
1575+
Letting :math:`z = (x - \mu)/b`,
1576+
1577+
.. math::
1578+
F(x \mid \mu, b) = \frac{1}{2} - \frac{1}{2}\,
1579+
\operatorname{sgn}(z)\,\left(e^{-|z|} - 1\right)
1580+
1581+
The implementation uses :func:`~jax.numpy.expm1` for numerical
1582+
stability near :math:`z = 0`.
1583+
1584+
:param value: Real-valued point :math:`x` at which to evaluate the CDF.
1585+
:return: CDF values in :math:`[0, 1]`.
1586+
"""
14571587
scaled = (value - self.loc) / self.scale
14581588
return 0.5 - 0.5 * jnp.sign(scaled) * jnp.expm1(-jnp.abs(scaled))
14591589

14601590
def icdf(self, q: ArrayLike) -> ArrayLike:
1591+
r"""Inverse CDF (quantile function) of the Laplace distribution:
1592+
1593+
.. math::
1594+
F^{-1}(q \mid \mu, b) = \mu - b\,\mathrm{sgn}\left(q - \frac{1}{2}\right)\,
1595+
\ln\!\left(1 - 2 \left| q - \frac{1}{2} \right| \right),
1596+
\quad q \in [0, 1]
1597+
1598+
:param q: Quantile values in :math:`[0, 1]`.
1599+
:return: Real-valued quantiles of the Laplace distribution at ``q``.
1600+
"""
14611601
a = q - 0.5
14621602
return self.loc - self.scale * jnp.sign(a) * jnp.log1p(-2 * jnp.abs(a))
14631603

14641604
def entropy(self) -> ArrayLike:
1605+
r"""Differential entropy of the Laplace distribution:
1606+
1607+
.. math::
1608+
H(X) = \ln(2 b) + 1
1609+
"""
14651610
return jnp.log(2 * self.scale) + 1
14661611

14671612

@@ -1782,6 +1927,25 @@ def entropy(self) -> ArrayLike:
17821927

17831928

17841929
class Logistic(Distribution):
1930+
r"""The Logistic distribution, a continuous real-valued distribution
1931+
parameterized by location :math:`\mu` and scale :math:`s > 0`. Its CDF is
1932+
the standard logistic (sigmoid) function shifted and scaled to :math:`\mu`,
1933+
:math:`s`, which makes it the natural latent distribution underlying
1934+
logistic regression.
1935+
1936+
The Probability Density Function (PDF) is:
1937+
1938+
.. math::
1939+
f(x \mid \mu, s) = \frac{
1940+
\exp\!\left(-\displaystyle\frac{x - \mu}{s}\right)
1941+
}{
1942+
s \left(1 + \exp\!\left(-\displaystyle\frac{x - \mu}{s}\right)\right)^{2}
1943+
}, \quad x \in \mathbb{R}
1944+
1945+
where :math:`\mu \in \mathbb{R}` is the location (:attr:`loc`) and
1946+
:math:`s > 0` is the scale (:attr:`scale`).
1947+
"""
1948+
17851949
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
17861950
support = constraints.real
17871951
reparametrized_params = ["loc", "scale"]
@@ -1793,11 +1957,24 @@ def __init__(
17931957
*,
17941958
validate_args: Optional[bool] = None,
17951959
) -> None:
1960+
r"""
1961+
:param loc: Location parameter :math:`\mu \in \mathbb{R}`. Defaults to ``0.0``.
1962+
:param scale: Scale parameter :math:`s > 0`. Defaults to ``1.0``.
1963+
:param validate_args: If True, enforce domain constraints during initialization.
1964+
"""
17961965
self.loc, self.scale = promote_shapes(loc, scale)
17971966
batch_shape = lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
17981967
super(Logistic, self).__init__(batch_shape, validate_args=validate_args)
17991968

18001969
def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLike:
1970+
r"""Draw samples via the location-scale transform
1971+
:math:`X = \mu + s Z`, where :math:`Z \sim \mathrm{Logistic}(0, 1)` is
1972+
drawn from :func:`~jax.random.logistic`.
1973+
1974+
:param key: A JAX PRNG key.
1975+
:param sample_shape: Sample dimensions to prepend to the batch shape.
1976+
:return: Real-valued samples from the Logistic distribution.
1977+
"""
18011978
assert is_prng_key(key)
18021979
z = random.logistic(
18031980
key, shape=sample_shape + self.batch_shape + self.event_shape
@@ -1806,27 +1983,79 @@ def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLik
18061983

18071984
@validate_sample
18081985
def log_prob(self, value: ArrayLike) -> ArrayLike:
1986+
r"""Evaluate the log probability density function at ``value``.
1987+
1988+
Letting :math:`u = (\mu - x)/s`, the log PDF is
1989+
1990+
.. math::
1991+
\ln f(x \mid \mu, s) = u - \ln s - 2 \ln(1 + e^{u})
1992+
1993+
The implementation uses :func:`~jax.nn.softplus` for
1994+
:math:`\ln(1 + e^{u})`, which is numerically stable for both large
1995+
positive and large negative values of :math:`u`.
1996+
1997+
:param value: Real-valued point :math:`x` at which to evaluate the log PDF.
1998+
:return: Log probability density evaluated under the Logistic distribution.
1999+
"""
18092000
log_exponent = (self.loc - value) / self.scale
18102001
log_denominator = jnp.log(self.scale) + 2 * nn.softplus(log_exponent)
18112002
return log_exponent - log_denominator
18122003

18132004
@property
18142005
def mean(self) -> ArrayLike:
2006+
r"""Mean of the Logistic distribution:
2007+
2008+
.. math::
2009+
\mathbb{E}[X] = \mu
2010+
"""
18152011
return jnp.broadcast_to(self.loc, self.batch_shape)
18162012

18172013
@property
18182014
def variance(self) -> ArrayLike:
2015+
r"""Variance of the Logistic distribution:
2016+
2017+
.. math::
2018+
\mathrm{Var}(X) = \frac{\pi^2 s^2}{3}
2019+
"""
18192020
var = (self.scale**2) * (jnp.pi**2) / 3
18202021
return jnp.broadcast_to(var, self.batch_shape)
18212022

18222023
def cdf(self, value: ArrayLike) -> ArrayLike:
2024+
r"""Cumulative Distribution Function (CDF) of the Logistic distribution.
2025+
Letting :math:`z = (x - \mu)/s`,
2026+
2027+
.. math::
2028+
F(x \mid \mu, s) = \sigma(z) = \frac{1}{1 + e^{-z}}
2029+
2030+
where :math:`\sigma` is the logistic sigmoid, computed via
2031+
:func:`~jax.scipy.special.expit`.
2032+
2033+
:param value: Real-valued point :math:`x` at which to evaluate the CDF.
2034+
:return: CDF values in :math:`[0, 1]`.
2035+
"""
18232036
scaled = (value - self.loc) / self.scale
18242037
return expit(scaled)
18252038

18262039
def icdf(self, q: ArrayLike) -> ArrayLike:
2040+
r"""Inverse CDF (quantile function) of the Logistic distribution:
2041+
2042+
.. math::
2043+
F^{-1}(q \mid \mu, s) = \mu + s\,\operatorname{logit}(q),
2044+
\quad q \in [0, 1]
2045+
2046+
where :math:`\operatorname{logit}(q) = \ln(q / (1 - q))`.
2047+
2048+
:param q: Quantile values in :math:`[0, 1]`.
2049+
:return: Real-valued quantiles of the Logistic distribution at ``q``.
2050+
"""
18272051
return self.loc + self.scale * logit(q)
18282052

18292053
def entropy(self) -> ArrayLike:
2054+
r"""Differential entropy of the Logistic distribution:
2055+
2056+
.. math::
2057+
H(X) = \ln(s) + 2
2058+
"""
18302059
return jnp.broadcast_to(jnp.log(self.scale) + 2, self.batch_shape)
18312060

18322061

0 commit comments

Comments
 (0)