Skip to content

Commit 597f91a

Browse files
committed
Implement transform for Wishart
1 parent 6182801 commit 597f91a

5 files changed

Lines changed: 437 additions & 170 deletions

File tree

pymc/distributions/multivariate.py

Lines changed: 158 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,19 @@
3535
)
3636
from pytensor.tensor.elemwise import DimShuffle
3737
from pytensor.tensor.exceptions import NotScalarConstantError
38-
from pytensor.tensor.linalg import det, eigh, solve_triangular, trace
38+
from pytensor.tensor.linalg import eigh, solve_triangular
3939
from pytensor.tensor.linalg import inv as matrix_inverse
4040
from pytensor.tensor.random import chisquare
4141
from pytensor.tensor.random.basic import MvNormalRV, dirichlet, multinomial, multivariate_normal
4242
from pytensor.tensor.random.op import RandomVariable
4343
from pytensor.tensor.random.utils import (
4444
normalize_size_param,
4545
)
46-
from scipy import stats
4746

4847
import pymc as pm
4948

5049
from pymc.distributions import transforms
51-
from pymc.distributions.continuous import BoundedContinuous, ChiSquared, Normal
50+
from pymc.distributions.continuous import BoundedContinuous
5251
from pymc.distributions.dist_math import (
5352
betaln,
5453
check_parameters,
@@ -74,6 +73,7 @@
7473
)
7574
from pymc.distributions.transforms import (
7675
CholeskyCorrTransform,
76+
CholeskyCovTransform,
7777
ZeroSumTransform,
7878
_default_transform,
7979
)
@@ -910,200 +910,221 @@ def matrix_pos_def(X, tol=1e-8):
910910
return pt.all(~pt.isnan(diag)) & pt.all(diag > tol)
911911

912912

913-
class WishartRV(RandomVariable):
913+
class WishartRV(SymbolicRandomVariable):
914914
name = "wishart"
915-
signature = "(),(p,p)->(p,p)"
916-
dtype = "floatX"
915+
extended_signature = "[rng],[size],(),(p,p)->[rng],(p,p)"
917916
_print_name = ("Wishart", "\\operatorname{Wishart}")
918917

918+
def update(self, node):
919+
return {node.inputs[0]: node.outputs[0]}
920+
919921
@classmethod
920-
def rng_fn(cls, rng, nu, V, size):
921-
scipy_size = size if size else 1 # Default size for Scipy's wishart.rvs is 1
922-
# Scipy doesn't accept batch nu or V
923-
nu = _squeeze_to_ndim(nu, 0)
924-
V = _squeeze_to_ndim(V, 2)
925-
result = stats.wishart.rvs(int(nu), V, size=scipy_size, random_state=rng)
926-
if size == (1,):
927-
return result[np.newaxis, ...]
922+
def rv_op(cls, nu, scale_chol, *, size=None, rng=None):
923+
nu = pt.as_tensor_variable(nu, dtype="int64")
924+
scale_chol = pt.as_tensor_variable(scale_chol)
925+
if scale_chol.type.ndim < 2:
926+
raise ValueError("Wishart `scale_chol` must have at least 2 dimensions")
927+
928+
rng = normalize_rng_param(rng)
929+
size = normalize_size_param(size)
930+
931+
p = scale_chol.shape[-1]
932+
933+
# Effective batch shape for the inner chi-squared / normal samples.
934+
# ``scale_chol`` keeps its original shape; the matmul ``scale_chol @ A``
935+
# broadcasts naturally. When the user does not pass an explicit ``size``,
936+
# the batch shape is inferred from broadcasting ``nu`` (ndim_supp 0) and
937+
# ``scale_chol`` (ndim_supp 2).
938+
if rv_size_is_none(size):
939+
batch_shape = tuple(implicit_size_from_params(nu, scale_chol, ndims_params=[0, 2]))
928940
else:
929-
return result
941+
batch_shape = tuple(size)
942+
943+
# Bartlett decomposition: A is a (..., p, p) lower-triangular matrix
944+
# with ``sqrt(chi^2_{nu - k})`` on the k-th diagonal entry and standard
945+
# normals strictly below. Then ``L_X = scale_chol @ A`` is the Cholesky
946+
# factor of a ``Wishart(nu, scale_chol scale_chol^T)`` draw, and the
947+
# SPD draw is ``L_X L_X^T``. ``nu`` may itself be batched, so we expand
948+
# a trailing axis for the per-diagonal subtraction.
949+
chi_dofs = nu[..., None] - pt.arange(p, dtype="int64") # (..., p)
950+
next_rng, chi_sq = pt.random.chisquare(
951+
df=chi_dofs,
952+
size=(*batch_shape, p),
953+
rng=rng,
954+
return_next_rng=True,
955+
)
956+
chi_diag = pt.sqrt(chi_sq) # (..., p)
957+
958+
n_offdiag = (p * (p - 1)) // 2
959+
next_rng, offdiag = pt.random.normal(
960+
loc=0.0,
961+
scale=1.0,
962+
size=(*batch_shape, n_offdiag),
963+
rng=next_rng,
964+
return_next_rng=True,
965+
)
966+
967+
A = pt.zeros((*batch_shape, p, p), dtype=scale_chol.dtype)
968+
diag_idx = pt.arange(p)
969+
A = A[..., diag_idx, diag_idx].set(chi_diag)
970+
tril_idx = pt.tril_indices(p, k=-1)
971+
A = A[..., tril_idx[0], tril_idx[1]].set(offdiag)
930972

973+
L_X = scale_chol @ A
974+
X = L_X @ L_X.mT
931975

932-
wishart = WishartRV()
976+
return cls(
977+
inputs=[rng, size, nu, scale_chol],
978+
outputs=[next_rng, X],
979+
)(rng, size, nu, scale_chol)
933980

934981

935982
class Wishart(Continuous):
936983
r"""
937984
Wishart distribution.
938985
939-
The Wishart distribution is the probability distribution of the
940-
maximum-likelihood estimator (MLE) of the precision matrix of a
941-
multivariate normal distribution. If V=1, the distribution is
942-
identical to the chi-square distribution with nu degrees of
943-
freedom.
986+
The Wishart distribution is the probability distribution over symmetric
987+
positive-definite matrices that arises as the distribution of the sum of
988+
outer products of i.i.d. multivariate normal vectors. If
989+
:math:`x_i \sim \mathcal{N}(0, V)` are i.i.d. for :math:`i = 1, \dots, \nu`,
990+
then :math:`X = \sum_i x_i x_i^\top \sim \mathrm{Wishart}_p(\nu, V)`.
944991
945992
.. math::
946993
947-
f(X \mid nu, T) =
948-
\frac{{\mid T \mid}^{nu/2}{\mid X \mid}^{(nu-k-1)/2}}{2^{nu k/2}
949-
\Gamma_p(nu/2)} \exp\left\{ -\frac{1}{2} Tr(TX) \right\}
994+
f(X \mid \nu, V) =
995+
\frac{|X|^{(\nu-p-1)/2}}{2^{\nu p / 2} |V|^{\nu / 2} \Gamma_p(\nu / 2)}
996+
\exp\left\{ -\frac{1}{2} \operatorname{tr}(V^{-1} X) \right\}
950997
951-
where :math:`k` is the rank of :math:`X`.
998+
where :math:`p` is the rank of :math:`X`.
952999
9531000
======== =========================================
954-
Support :math:`X(p x p)` positive definite matrix
955-
Mean :math:`nu V`
956-
Variance :math:`nu (v_{ij}^2 + v_{ii} v_{jj})`
1001+
Support :math:`X\,(p \times p)` positive definite matrix
1002+
Mean :math:`\nu V`
1003+
Variance :math:`\nu (v_{ij}^2 + v_{ii} v_{jj})`
9571004
======== =========================================
9581005
9591006
Parameters
9601007
----------
9611008
nu : tensor_like of int
962-
Degrees of freedom, > 0.
963-
V : tensor_like of float
964-
p x p positive definite matrix.
1009+
Degrees of freedom, must satisfy ``nu > p - 1``.
1010+
V : tensor_like of float, optional
1011+
``(p, p)`` symmetric positive-definite scale matrix. Mutually exclusive with
1012+
``scale_chol``.
1013+
scale_chol : tensor_like of float, optional
1014+
``(p, p)`` lower-triangular Cholesky factor of the scale matrix
1015+
(``V = scale_chol @ scale_chol.T``). Provide this when you already have the
1016+
decomposition to avoid a redundant Cholesky inside ``logp``. Mutually
1017+
exclusive with ``V``.
9651018
9661019
Notes
9671020
-----
968-
This distribution is unusable in a PyMC model. You should instead
969-
use LKJCholeskyCov or LKJCorr.
1021+
The default unconstraining transform is :class:`CholeskyCovTransform`, which
1022+
parameterizes ``X = L @ L.T`` from a free real vector with ``log L_kk`` on the
1023+
diagonal.
9701024
"""
9711025

972-
rv_op = wishart
1026+
rv_type = WishartRV
1027+
rv_op = WishartRV.rv_op
1028+
1029+
@classmethod
1030+
def _resolve_scale(cls, V, scale_chol):
1031+
if (V is None) == (scale_chol is None):
1032+
raise ValueError("Wishart requires exactly one of `V` or `scale_chol`.")
1033+
if scale_chol is not None:
1034+
return pt.as_tensor_variable(scale_chol)
1035+
return pt.linalg.cholesky(pt.as_tensor_variable(V))
9731036

9741037
@classmethod
975-
def dist(cls, nu, V, *args, **kwargs):
1038+
def dist(cls, nu, V=None, *args, scale_chol=None, **kwargs):
9761039
nu = pt.as_tensor_variable(nu, dtype=int)
977-
V = pt.as_tensor_variable(V)
978-
979-
warnings.warn(
980-
"The Wishart distribution can currently not be used "
981-
"for MCMC sampling. The probability of sampling a "
982-
"symmetric matrix is basically zero. Instead, please "
983-
"use LKJCholeskyCov or LKJCorr. For more information "
984-
"on the issues surrounding the Wishart see here: "
985-
"https://github.com/pymc-devs/pymc/issues/538.",
986-
UserWarning,
987-
)
1040+
scale_chol = cls._resolve_scale(V, scale_chol)
1041+
return super().dist([nu, scale_chol], *args, **kwargs)
9881042

989-
# mean = nu * V
990-
# p = V.shape[0]
991-
# mode = pt.switch(pt.ge(nu, p + 1), (nu - p - 1) * V, np.nan)
992-
return super().dist([nu, V], *args, **kwargs)
1043+
def support_point(rv, size, nu, scale_chol):
1044+
# Mean of Wishart(nu, V) is nu * V = nu * (L_V @ L_V.T). Always in the
1045+
# SPD support for valid nu > p - 1, so it's a safe initial point.
1046+
V = scale_chol @ scale_chol.mT
1047+
return nu.astype(V.dtype) * V
9931048

994-
def logp(X, nu, V):
1049+
def logp(X, nu, scale_chol):
9951050
"""
996-
Calculate logp of Wishart distribution at specified value.
1051+
Log-density of the Wishart distribution at the SPD value ``X``.
9971052
998-
Parameters
999-
----------
1000-
X: numeric
1001-
Value for which log-probability is calculated.
1002-
1003-
Returns
1004-
-------
1005-
TensorVariable
1053+
Implemented in Cholesky form: when the value comes from the unconstraining
1054+
``CholeskyCovTransform``, ``cholesky(L @ L.T)`` rewrites to ``L`` and no
1055+
decomposition runs at runtime.
10061056
"""
1007-
p = V.shape[0]
1057+
p = X.shape[-1]
10081058

1009-
IVI = det(V)
1010-
IXI = det(X)
1059+
L_X = pt.linalg.cholesky(X)
1060+
log_det_X = 2 * pt.sum(pt.log(pt.diagonal(L_X, axis1=-2, axis2=-1)), axis=-1)
1061+
log_det_V = 2 * pt.sum(pt.log(pt.diagonal(scale_chol, axis1=-2, axis2=-1)), axis=-1)
1062+
# tr(V^{-1} X) = ||L_V^{-1} L_X||_F^2 via a triangular solve.
1063+
M = solve_triangular(scale_chol, L_X, lower=True)
1064+
tr_term = pt.sum(M**2, axis=(-2, -1))
10111065

10121066
return check_parameters(
10131067
(
1014-
(nu - p - 1) * pt.log(IXI)
1015-
- trace(matrix_inverse(V).dot(X))
1068+
(nu - p - 1) * log_det_X
1069+
- tr_term
10161070
- nu * p * pt.log(2)
1017-
- nu * pt.log(IVI)
1071+
- nu * log_det_V
10181072
- 2 * multigammaln(nu / 2.0, p)
10191073
)
10201074
/ 2,
1021-
matrix_pos_def(X),
1022-
pt.eq(X, X.T),
10231075
nu > (p - 1),
1076+
msg="nu > p - 1",
10241077
)
10251078

10261079

1027-
def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initval=None):
1028-
r"""
1029-
Bartlett decomposition of the Wishart distribution.
1030-
1031-
As the Wishart distribution requires the matrix to be symmetric positive
1032-
semi-definite, it is impossible for MCMC to ever propose acceptable matrices.
1033-
1034-
Instead, we can use the Bartlett decomposition which samples a lower
1035-
diagonal matrix. Specifically:
1080+
@_default_transform.register(WishartRV)
1081+
def wishart_default_transform(op, rv):
1082+
_, _, _, scale_chol = rv.owner.inputs
1083+
n = scale_chol.shape[-1]
1084+
return CholeskyCovTransform(n=n)
10361085

1037-
.. math::
1038-
\text{If } L \sim \begin{pmatrix}
1039-
\sqrt{c_1} & 0 & 0 \\
1040-
z_{21} & \sqrt{c_2} & 0 \\
1041-
z_{31} & z_{32} & \sqrt{c_3}
1042-
\end{pmatrix}
10431086

1044-
\text{with } c_i \sim \chi^2(n-i+1) \text{ and } z_{ij} \sim \mathcal{N}(0, 1), \text{ then} \\
1045-
L \times A \times A^T \times L^T \sim \text{Wishart}(L \times L^T, \nu)
1087+
def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initval=None):
1088+
r"""
1089+
Bartlett-decomposed Wishart prior. **Deprecated**: use :class:`Wishart` directly.
10461090
1047-
See http://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition
1048-
for more information.
1091+
This used to be the only MCMC-usable Wishart in PyMC, since the legacy
1092+
:class:`Wishart` had no unconstraining transform. The new :class:`Wishart`
1093+
is parameterized through its Cholesky factor with a default
1094+
:class:`CholeskyCovTransform`, so this helper is no longer needed and is
1095+
a thin shim around it for backward compatibility.
10491096
10501097
Parameters
10511098
----------
10521099
S : ndarray
1053-
p x p positive definite matrix
1054-
Or:
1055-
p x p lower-triangular matrix that is the Cholesky factor
1056-
of the covariance matrix.
1100+
``(p, p)`` positive-definite scale matrix, or its lower-triangular
1101+
Cholesky factor when ``is_cholesky=True``.
10571102
nu : tensor_like of int
1058-
Degrees of freedom, > dim(S).
1059-
is_cholesky : bool, default=False
1060-
Input matrix S is already Cholesky decomposed as S.T * S
1061-
return_cholesky : bool, default=False
1062-
Only return the Cholesky decomposed matrix.
1063-
initval : ndarray
1064-
p x p positive definite matrix used to initialize
1065-
1066-
Notes
1067-
-----
1068-
This is not a standard Distribution class but follows a similar
1069-
interface. Besides the Wishart distribution, it will add RVs
1070-
name_c and name_z to your model which make up the matrix.
1071-
1072-
This distribution is usually a bad idea to use as a prior for multivariate
1073-
normal. You should instead use LKJCholeskyCov or LKJCorr.
1103+
Degrees of freedom, > ``dim(S) - 1``.
1104+
is_cholesky : bool, default False
1105+
If True, ``S`` is interpreted as the Cholesky factor of the scale matrix
1106+
(mapped to :class:`Wishart`'s ``scale_chol`` argument).
1107+
return_cholesky : bool, default False
1108+
If True, return the Cholesky factor of the Wishart draw rather than the
1109+
SPD matrix itself.
1110+
initval : ndarray, optional
1111+
Forwarded to :class:`Wishart`. Note that ``initval`` semantics changed:
1112+
the new :class:`Wishart` expects an SPD matrix value, not the legacy
1113+
Bartlett component initvals.
10741114
"""
1075-
L = S if is_cholesky else scipy.linalg.cholesky(S)
1076-
diag_idx = np.diag_indices_from(S)
1077-
tril_idx = np.tril_indices_from(S, k=-1)
1078-
n_diag = len(diag_idx[0])
1079-
n_tril = len(tril_idx[0])
1080-
1081-
if initval is not None:
1082-
# Inverse transform
1083-
initval = np.dot(np.dot(np.linalg.inv(L), initval), np.linalg.inv(L.T))
1084-
initval = scipy.linalg.cholesky(initval, lower=True)
1085-
diag_testval = initval[diag_idx] ** 2
1086-
tril_testval = initval[tril_idx]
1087-
else:
1088-
diag_testval = None
1089-
tril_testval = None
1090-
1091-
c = pt.sqrt(
1092-
ChiSquared(f"{name}_c", nu - np.arange(2, 2 + n_diag), shape=n_diag, initval=diag_testval)
1115+
warnings.warn(
1116+
"WishartBartlett is deprecated and will be removed in a future release. "
1117+
"Use pm.Wishart directly. For `is_cholesky=True`, pass `scale_chol=S`. "
1118+
"For `return_cholesky=True`, wrap pm.Wishart in pt.linalg.cholesky as a "
1119+
"Deterministic.",
1120+
FutureWarning,
1121+
stacklevel=2,
10931122
)
1094-
pm._log.info(f"Added new variable {name}_c to model diagonal of Wishart.")
1095-
z = Normal(f"{name}_z", 0.0, 1.0, shape=n_tril, initval=tril_testval)
1096-
pm._log.info(f"Added new variable {name}_z to model off-diagonals of Wishart.")
1097-
# Construct A matrix
1098-
A = pt.zeros(S.shape, dtype=np.float32)
1099-
A = pt.set_subtensor(A[diag_idx], c)
1100-
A = pt.set_subtensor(A[tril_idx], z)
1101-
1102-
# L * A * A.T * L.T ~ Wishart(L*L.T, nu)
1123+
scale_kwargs = {"scale_chol": S} if is_cholesky else {"V": S}
11031124
if return_cholesky:
1104-
return pm.Deterministic(name, pt.dot(L, A))
1105-
else:
1106-
return pm.Deterministic(name, pt.dot(pt.dot(pt.dot(L, A), A.T), L.T))
1125+
wishart_rv = Wishart(f"_{name}_wishart", nu=nu, initval=initval, **scale_kwargs)
1126+
return pm.Deterministic(name, pt.linalg.cholesky(wishart_rv))
1127+
return Wishart(name, nu=nu, initval=initval, **scale_kwargs)
11071128

11081129

11091130
def _lkj_normalizing_constant(eta, n):

0 commit comments

Comments
 (0)