@@ -841,7 +841,7 @@ def variance(self) -> ArrayLike:
841841 r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then
842842
843843 .. math::
844- \mathrm{Var}[X] = \frac{\alpha}{\lambda^2}
844+ \mathrm{Var}(X) = \frac{\alpha}{\lambda^2}
845845 """
846846 return self .concentration / jnp .power (self .rate , 2 )
847847
@@ -1306,6 +1306,22 @@ def mean(self) -> ArrayLike:
13061306
13071307
13081308class Gumbel (Distribution ):
1309+ r"""The Gumbel (maximum) distribution, a continuous real-valued
1310+ distribution parameterized by location :math:`\mu` and scale :math:`\beta > 0`.
1311+ It is the limiting distribution of the maximum of a large number of i.i.d.
1312+ samples from an exponential-tailed distribution.
1313+
1314+ The Probability Density Function (PDF) is:
1315+
1316+ .. math::
1317+ f(x \mid \mu, \beta) = \frac{1}{\beta} \exp\!\left(
1318+ -\frac{x - \mu}{\beta} - \exp\!\left(-\frac{x - \mu}{\beta}\right)
1319+ \right), \quad x \in \mathbb{R}
1320+
1321+ where :math:`\mu \in \mathbb{R}` is the location (:attr:`loc`) and
1322+ :math:`\beta > 0` is the scale (:attr:`scale`).
1323+ """
1324+
13091325 arg_constraints = {"loc" : constraints .real , "scale" : constraints .positive }
13101326 support = constraints .real
13111327 reparametrized_params = ["loc" , "scale" ]
@@ -1317,6 +1333,11 @@ def __init__(
13171333 * ,
13181334 validate_args : Optional [bool ] = None ,
13191335 ) -> None :
1336+ r"""
1337+ :param loc: Location parameter :math:`\mu \in \mathbb{R}`. Defaults to ``0.0``.
1338+ :param scale: Scale parameter :math:`\beta > 0`. Defaults to ``1.0``.
1339+ :param validate_args: If True, enforce domain constraints during initialization.
1340+ """
13201341 self .loc , self .scale = promote_shapes (loc , scale )
13211342 batch_shape = lax .broadcast_shapes (jnp .shape (loc ), jnp .shape (scale ))
13221343
@@ -1325,6 +1346,15 @@ def __init__(
13251346 )
13261347
13271348 def sample (self , key : jax .Array , sample_shape : tuple [int , ...] = ()) -> ArrayLike :
1349+ r"""Draw samples from the Gumbel distribution via the location-scale
1350+ transform :math:`X = \mu + \beta Z`, where
1351+ :math:`Z \sim \mathrm{Gumbel}(0, 1)` is drawn from
1352+ :func:`~jax.random.gumbel`.
1353+
1354+ :param key: A JAX PRNG key.
1355+ :param sample_shape: Sample dimensions to prepend to the batch shape.
1356+ :return: Real-valued samples from the Gumbel distribution.
1357+ """
13281358 assert is_prng_key (key )
13291359 standard_gumbel_sample = random .gumbel (
13301360 key , shape = sample_shape + self .batch_shape + self .event_shape
@@ -1333,23 +1363,63 @@ def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLik
13331363
13341364 @validate_sample
13351365 def log_prob (self , value : ArrayLike ) -> ArrayLike :
1366+ r"""Evaluate the log probability density function at ``value``.
1367+
1368+ Letting :math:`z = (x - \mu)/\beta`,
1369+
1370+ .. math::
1371+ \ln f(x \mid \mu, \beta) = -z - e^{-z} - \ln \beta
1372+
1373+ :param value: Real-valued point :math:`x` at which to evaluate the log PDF.
1374+ :return: Log probability density evaluated under the Gumbel distribution.
1375+ """
13361376 z = (value - self .loc ) / self .scale
13371377 return - (z + jnp .exp (- z )) - jnp .log (self .scale )
13381378
13391379 @property
13401380 def mean (self ) -> ArrayLike :
1381+ r"""Mean of the Gumbel distribution:
1382+
1383+ .. math::
1384+ \mathbb{E}[X] = \mu + \beta \gamma
1385+
1386+ where :math:`\gamma \approx 0.5772\ldots` is the Euler-Mascheroni constant,
1387+ available at, :data:`~jax.numpy.euler_gamma`.
1388+ """
13411389 return jnp .broadcast_to (
13421390 self .loc + self .scale * jnp .euler_gamma , self .batch_shape
13431391 )
13441392
13451393 @property
13461394 def variance (self ) -> ArrayLike :
1395+ r"""Variance of the Gumbel distribution:
1396+
1397+ .. math::
1398+ \mathrm{Var}(X) = \frac{\pi^2}{6} \beta^2
1399+ """
13471400 return jnp .broadcast_to (jnp .pi ** 2 / 6.0 * self .scale ** 2 , self .batch_shape )
13481401
13491402 def cdf (self , value : ArrayLike ) -> ArrayLike :
1403+ r"""Cumulative Distribution Function (CDF) of the Gumbel distribution:
1404+
1405+ .. math::
1406+ F(x \mid \mu, \beta) = \exp\!\left(-\exp\!\left(-\frac{x - \mu}{\beta}\right)\right)
1407+
1408+ :param value: Real-valued point :math:`x` at which to evaluate the CDF.
1409+ :return: CDF values in :math:`[0, 1]`.
1410+ """
13501411 return jnp .exp (- jnp .exp ((self .loc - value ) / self .scale ))
13511412
13521413 def icdf (self , q : ArrayLike ) -> ArrayLike :
1414+ r"""Inverse CDF (quantile function) of the Gumbel distribution:
1415+
1416+ .. math::
1417+ F^{-1}(q \mid \mu, \beta) = \mu - \beta \ln(-\ln q),
1418+ \quad q \in (0, 1)
1419+
1420+ :param q: Quantile values in :math:`(0, 1)`.
1421+ :return: Real-valued quantiles of the Gumbel distribution at ``q``.
1422+ """
13531423 return self .loc - self .scale * jnp .log (- jnp .log (q ))
13541424
13551425
@@ -1415,6 +1485,22 @@ def variance(self) -> ArrayLike:
14151485
14161486
14171487class Laplace (Distribution ):
1488+ r"""The Laplace (double-exponential) distribution, a continuous real-valued
1489+ distribution parameterized by location :math:`\mu` and scale :math:`b > 0`.
1490+ It is the distribution of the difference of two i.i.d. exponential variates
1491+ and has heavier tails than the Normal distribution.
1492+
1493+ The Probability Density Function (PDF) is:
1494+
1495+ .. math::
1496+ f(x \mid \mu, b) = \frac{1}{2 b}
1497+ \exp\!\left(-\frac{|x - \mu|}{b}\right),
1498+ \quad x \in \mathbb{R}
1499+
1500+ where :math:`\mu \in \mathbb{R}` is the location (:attr:`loc`) and
1501+ :math:`b > 0` is the scale (:attr:`scale`).
1502+ """
1503+
14181504 arg_constraints = {"loc" : constraints .real , "scale" : constraints .positive }
14191505 support = constraints .real
14201506 reparametrized_params = ["loc" , "scale" ]
@@ -1426,13 +1512,26 @@ def __init__(
14261512 * ,
14271513 validate_args : Optional [bool ] = None ,
14281514 ) -> None :
1515+ r"""
1516+ :param loc: Location parameter :math:`\mu \in \mathbb{R}`. Defaults to ``0.0``.
1517+ :param scale: Scale parameter :math:`b > 0`. Defaults to ``1.0``.
1518+ :param validate_args: If True, enforce domain constraints during initialization.
1519+ """
14291520 self .loc , self .scale = promote_shapes (loc , scale )
14301521 batch_shape = lax .broadcast_shapes (jnp .shape (loc ), jnp .shape (scale ))
14311522 super (Laplace , self ).__init__ (
14321523 batch_shape = batch_shape , validate_args = validate_args
14331524 )
14341525
14351526 def sample (self , key : jax .Array , sample_shape : tuple [int , ...] = ()) -> ArrayLike :
1527+ r"""Draw samples via the location-scale transform
1528+ :math:`X = \mu + b Z`, where :math:`Z \sim \mathrm{Laplace}(0, 1)` is
1529+ drawn from :func:`~jax.random.laplace`.
1530+
1531+ :param key: A JAX PRNG key.
1532+ :param sample_shape: Sample dimensions to prepend to the batch shape.
1533+ :return: Real-valued samples from the Laplace distribution.
1534+ """
14361535 assert is_prng_key (key )
14371536 eps = random .laplace (
14381537 key , shape = sample_shape + self .batch_shape + self .event_shape
@@ -1441,27 +1540,73 @@ def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLik
14411540
14421541 @validate_sample
14431542 def log_prob (self , value : ArrayLike ) -> ArrayLike :
1543+ r"""Evaluate the log probability density function at ``value``:
1544+
1545+ .. math::
1546+ \ln f(x \mid \mu, b) = -\frac{|x - \mu|}{b} - \ln(2 b)
1547+
1548+ :param value: Real-valued point :math:`x` at which to evaluate the log PDF.
1549+ :return: Log probability density evaluated under the Laplace distribution.
1550+ """
14441551 normalize_term = jnp .log (2 * self .scale )
14451552 value_scaled = jnp .abs (value - self .loc ) / self .scale
14461553 return - value_scaled - normalize_term
14471554
14481555 @property
14491556 def mean (self ) -> ArrayLike :
1557+ r"""Mean of the Laplace distribution:
1558+
1559+ .. math::
1560+ \mathbb{E}[X] = \mu
1561+ """
14501562 return jnp .broadcast_to (self .loc , self .batch_shape )
14511563
14521564 @property
14531565 def variance (self ) -> ArrayLike :
1566+ r"""Variance of the Laplace distribution:
1567+
1568+ .. math::
1569+ \mathrm{Var}(X) = 2 b^2
1570+ """
14541571 return jnp .broadcast_to (2 * self .scale ** 2 , self .batch_shape )
14551572
14561573 def cdf (self , value : ArrayLike ) -> ArrayLike :
1574+ r"""Cumulative Distribution Function (CDF) of the Laplace distribution.
1575+ Letting :math:`z = (x - \mu)/b`,
1576+
1577+ .. math::
1578+ F(x \mid \mu, b) = \frac{1}{2} - \frac{1}{2}\,
1579+ \operatorname{sgn}(z)\,\left(e^{-|z|} - 1\right)
1580+
1581+ The implementation uses :func:`~jax.numpy.expm1` for numerical
1582+ stability near :math:`z = 0`.
1583+
1584+ :param value: Real-valued point :math:`x` at which to evaluate the CDF.
1585+ :return: CDF values in :math:`[0, 1]`.
1586+ """
14571587 scaled = (value - self .loc ) / self .scale
14581588 return 0.5 - 0.5 * jnp .sign (scaled ) * jnp .expm1 (- jnp .abs (scaled ))
14591589
14601590 def icdf (self , q : ArrayLike ) -> ArrayLike :
1591+ r"""Inverse CDF (quantile function) of the Laplace distribution:
1592+
1593+ .. math::
1594+ F^{-1}(q \mid \mu, b) = \mu - b\,\mathrm{sgn}\left(q - \frac{1}{2}\right)\,
1595+ \ln\!\left(1 - 2 \left| q - \frac{1}{2} \right| \right),
1596+ \quad q \in [0, 1]
1597+
1598+ :param q: Quantile values in :math:`[0, 1]`.
1599+ :return: Real-valued quantiles of the Laplace distribution at ``q``.
1600+ """
14611601 a = q - 0.5
14621602 return self .loc - self .scale * jnp .sign (a ) * jnp .log1p (- 2 * jnp .abs (a ))
14631603
14641604 def entropy (self ) -> ArrayLike :
1605+ r"""Differential entropy of the Laplace distribution:
1606+
1607+ .. math::
1608+ H(X) = \ln(2 b) + 1
1609+ """
14651610 return jnp .log (2 * self .scale ) + 1
14661611
14671612
@@ -1782,6 +1927,25 @@ def entropy(self) -> ArrayLike:
17821927
17831928
17841929class Logistic (Distribution ):
1930+ r"""The Logistic distribution, a continuous real-valued distribution
1931+ parameterized by location :math:`\mu` and scale :math:`s > 0`. Its CDF is
1932+ the standard logistic (sigmoid) function shifted and scaled to :math:`\mu`,
1933+ :math:`s`, which makes it the natural latent distribution underlying
1934+ logistic regression.
1935+
1936+ The Probability Density Function (PDF) is:
1937+
1938+ .. math::
1939+ f(x \mid \mu, s) = \frac{
1940+ \exp\!\left(-\displaystyle\frac{x - \mu}{s}\right)
1941+ }{
1942+ s \left(1 + \exp\!\left(-\displaystyle\frac{x - \mu}{s}\right)\right)^{2}
1943+ }, \quad x \in \mathbb{R}
1944+
1945+ where :math:`\mu \in \mathbb{R}` is the location (:attr:`loc`) and
1946+ :math:`s > 0` is the scale (:attr:`scale`).
1947+ """
1948+
17851949 arg_constraints = {"loc" : constraints .real , "scale" : constraints .positive }
17861950 support = constraints .real
17871951 reparametrized_params = ["loc" , "scale" ]
@@ -1793,11 +1957,24 @@ def __init__(
17931957 * ,
17941958 validate_args : Optional [bool ] = None ,
17951959 ) -> None :
1960+ r"""
1961+ :param loc: Location parameter :math:`\mu \in \mathbb{R}`. Defaults to ``0.0``.
1962+ :param scale: Scale parameter :math:`s > 0`. Defaults to ``1.0``.
1963+ :param validate_args: If True, enforce domain constraints during initialization.
1964+ """
17961965 self .loc , self .scale = promote_shapes (loc , scale )
17971966 batch_shape = lax .broadcast_shapes (jnp .shape (loc ), jnp .shape (scale ))
17981967 super (Logistic , self ).__init__ (batch_shape , validate_args = validate_args )
17991968
18001969 def sample (self , key : jax .Array , sample_shape : tuple [int , ...] = ()) -> ArrayLike :
1970+ r"""Draw samples via the location-scale transform
1971+ :math:`X = \mu + s Z`, where :math:`Z \sim \mathrm{Logistic}(0, 1)` is
1972+ drawn from :func:`~jax.random.logistic`.
1973+
1974+ :param key: A JAX PRNG key.
1975+ :param sample_shape: Sample dimensions to prepend to the batch shape.
1976+ :return: Real-valued samples from the Logistic distribution.
1977+ """
18011978 assert is_prng_key (key )
18021979 z = random .logistic (
18031980 key , shape = sample_shape + self .batch_shape + self .event_shape
@@ -1806,27 +1983,79 @@ def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLik
18061983
18071984 @validate_sample
18081985 def log_prob (self , value : ArrayLike ) -> ArrayLike :
1986+ r"""Evaluate the log probability density function at ``value``.
1987+
1988+ Letting :math:`u = (\mu - x)/s`, the log PDF is
1989+
1990+ .. math::
1991+ \ln f(x \mid \mu, s) = u - \ln s - 2 \ln(1 + e^{u})
1992+
1993+ The implementation uses :func:`~jax.nn.softplus` for
1994+ :math:`\ln(1 + e^{u})`, which is numerically stable for both large
1995+ positive and large negative values of :math:`u`.
1996+
1997+ :param value: Real-valued point :math:`x` at which to evaluate the log PDF.
1998+ :return: Log probability density evaluated under the Logistic distribution.
1999+ """
18092000 log_exponent = (self .loc - value ) / self .scale
18102001 log_denominator = jnp .log (self .scale ) + 2 * nn .softplus (log_exponent )
18112002 return log_exponent - log_denominator
18122003
18132004 @property
18142005 def mean (self ) -> ArrayLike :
2006+ r"""Mean of the Logistic distribution:
2007+
2008+ .. math::
2009+ \mathbb{E}[X] = \mu
2010+ """
18152011 return jnp .broadcast_to (self .loc , self .batch_shape )
18162012
18172013 @property
18182014 def variance (self ) -> ArrayLike :
2015+ r"""Variance of the Logistic distribution:
2016+
2017+ .. math::
2018+ \mathrm{Var}(X) = \frac{\pi^2 s^2}{3}
2019+ """
18192020 var = (self .scale ** 2 ) * (jnp .pi ** 2 ) / 3
18202021 return jnp .broadcast_to (var , self .batch_shape )
18212022
18222023 def cdf (self , value : ArrayLike ) -> ArrayLike :
2024+ r"""Cumulative Distribution Function (CDF) of the Logistic distribution.
2025+ Letting :math:`z = (x - \mu)/s`,
2026+
2027+ .. math::
2028+ F(x \mid \mu, s) = \sigma(z) = \frac{1}{1 + e^{-z}}
2029+
2030+ where :math:`\sigma` is the logistic sigmoid, computed via
2031+ :func:`~jax.scipy.special.expit`.
2032+
2033+ :param value: Real-valued point :math:`x` at which to evaluate the CDF.
2034+ :return: CDF values in :math:`[0, 1]`.
2035+ """
18232036 scaled = (value - self .loc ) / self .scale
18242037 return expit (scaled )
18252038
18262039 def icdf (self , q : ArrayLike ) -> ArrayLike :
2040+ r"""Inverse CDF (quantile function) of the Logistic distribution:
2041+
2042+ .. math::
2043+ F^{-1}(q \mid \mu, s) = \mu + s\,\operatorname{logit}(q),
2044+ \quad q \in [0, 1]
2045+
2046+ where :math:`\operatorname{logit}(q) = \ln(q / (1 - q))`.
2047+
2048+ :param q: Quantile values in :math:`[0, 1]`.
2049+ :return: Real-valued quantiles of the Logistic distribution at ``q``.
2050+ """
18272051 return self .loc + self .scale * logit (q )
18282052
18292053 def entropy (self ) -> ArrayLike :
2054+ r"""Differential entropy of the Logistic distribution:
2055+
2056+ .. math::
2057+ H(X) = \ln(s) + 2
2058+ """
18302059 return jnp .broadcast_to (jnp .log (self .scale ) + 2 , self .batch_shape )
18312060
18322061
0 commit comments