|
35 | 35 | ) |
36 | 36 | from pytensor.tensor.elemwise import DimShuffle |
37 | 37 | from pytensor.tensor.exceptions import NotScalarConstantError |
38 | | -from pytensor.tensor.linalg import det, eigh, solve_triangular, trace |
| 38 | +from pytensor.tensor.linalg import eigh, solve_triangular |
39 | 39 | from pytensor.tensor.linalg import inv as matrix_inverse |
40 | 40 | from pytensor.tensor.random import chisquare |
41 | 41 | from pytensor.tensor.random.basic import MvNormalRV, dirichlet, multinomial, multivariate_normal |
42 | 42 | from pytensor.tensor.random.op import RandomVariable |
43 | 43 | from pytensor.tensor.random.utils import ( |
44 | 44 | normalize_size_param, |
45 | 45 | ) |
46 | | -from scipy import stats |
47 | 46 |
|
48 | 47 | import pymc as pm |
49 | 48 |
|
50 | 49 | from pymc.distributions import transforms |
51 | | -from pymc.distributions.continuous import BoundedContinuous, ChiSquared, Normal |
| 50 | +from pymc.distributions.continuous import BoundedContinuous |
52 | 51 | from pymc.distributions.dist_math import ( |
53 | 52 | betaln, |
54 | 53 | check_parameters, |
|
74 | 73 | ) |
75 | 74 | from pymc.distributions.transforms import ( |
76 | 75 | CholeskyCorrTransform, |
| 76 | + CholeskyCovTransform, |
77 | 77 | ZeroSumTransform, |
78 | 78 | _default_transform, |
79 | 79 | ) |
@@ -910,200 +910,221 @@ def matrix_pos_def(X, tol=1e-8): |
910 | 910 | return pt.all(~pt.isnan(diag)) & pt.all(diag > tol) |
911 | 911 |
|
912 | 912 |
|
913 | | -class WishartRV(RandomVariable): |
| 913 | +class WishartRV(SymbolicRandomVariable): |
914 | 914 | name = "wishart" |
915 | | - signature = "(),(p,p)->(p,p)" |
916 | | - dtype = "floatX" |
| 915 | + extended_signature = "[rng],[size],(),(p,p)->[rng],(p,p)" |
917 | 916 | _print_name = ("Wishart", "\\operatorname{Wishart}") |
918 | 917 |
|
| 918 | + def update(self, node): |
| 919 | + return {node.inputs[0]: node.outputs[0]} |
| 920 | + |
919 | 921 | @classmethod |
920 | | - def rng_fn(cls, rng, nu, V, size): |
921 | | - scipy_size = size if size else 1 # Default size for Scipy's wishart.rvs is 1 |
922 | | - # Scipy doesn't accept batch nu or V |
923 | | - nu = _squeeze_to_ndim(nu, 0) |
924 | | - V = _squeeze_to_ndim(V, 2) |
925 | | - result = stats.wishart.rvs(int(nu), V, size=scipy_size, random_state=rng) |
926 | | - if size == (1,): |
927 | | - return result[np.newaxis, ...] |
| 922 | + def rv_op(cls, nu, scale_chol, *, size=None, rng=None): |
| 923 | + nu = pt.as_tensor_variable(nu, dtype="int64") |
| 924 | + scale_chol = pt.as_tensor_variable(scale_chol) |
| 925 | + if scale_chol.type.ndim < 2: |
| 926 | + raise ValueError("Wishart `scale_chol` must have at least 2 dimensions") |
| 927 | + |
| 928 | + rng = normalize_rng_param(rng) |
| 929 | + size = normalize_size_param(size) |
| 930 | + |
| 931 | + p = scale_chol.shape[-1] |
| 932 | + |
| 933 | + # Effective batch shape for the inner chi-squared / normal samples. |
| 934 | + # ``scale_chol`` keeps its original shape; the matmul ``scale_chol @ A`` |
| 935 | + # broadcasts naturally. When the user does not pass an explicit ``size``, |
| 936 | + # the batch shape is inferred from broadcasting ``nu`` (core ndim 0) and |
| 937 | + # ``scale_chol`` (core ndim 2). |
| 938 | + if rv_size_is_none(size): |
| 939 | + batch_shape = tuple(implicit_size_from_params(nu, scale_chol, ndims_params=[0, 2])) |
928 | 940 | else: |
929 | | - return result |
| 941 | + batch_shape = tuple(size) |
| 942 | + |
| 943 | + # Bartlett decomposition: A is a (..., p, p) lower-triangular matrix |
| 944 | + # with ``sqrt(chi^2_{nu - k})`` on the k-th (0-based) diagonal entry and standard |
| 945 | + # normals strictly below. Then ``L_X = scale_chol @ A`` is the Cholesky |
| 946 | + # factor of a ``Wishart(nu, scale_chol scale_chol^T)`` draw, and the |
| 947 | + # SPD draw is ``L_X L_X^T``. ``nu`` may itself be batched, so we expand |
| 948 | + # a trailing axis for the per-diagonal subtraction. |
| 949 | + chi_dofs = nu[..., None] - pt.arange(p, dtype="int64") # (..., p) |
| 950 | + next_rng, chi_sq = pt.random.chisquare( |
| 951 | + df=chi_dofs, |
| 952 | + size=(*batch_shape, p), |
| 953 | + rng=rng, |
| 954 | + return_next_rng=True, |
| 955 | + ) |
| 956 | + chi_diag = pt.sqrt(chi_sq) # (..., p) |
| 957 | + |
| 958 | + n_offdiag = (p * (p - 1)) // 2 |
| 959 | + next_rng, offdiag = pt.random.normal( |
| 960 | + loc=0.0, |
| 961 | + scale=1.0, |
| 962 | + size=(*batch_shape, n_offdiag), |
| 963 | + rng=next_rng, |
| 964 | + return_next_rng=True, |
| 965 | + ) |
| 966 | + |
| 967 | + A = pt.zeros((*batch_shape, p, p), dtype=scale_chol.dtype) |
| 968 | + diag_idx = pt.arange(p) |
| 969 | + A = A[..., diag_idx, diag_idx].set(chi_diag) |
| 970 | + tril_idx = pt.tril_indices(p, k=-1) |
| 971 | + A = A[..., tril_idx[0], tril_idx[1]].set(offdiag) |
930 | 972 |
|
| 973 | + L_X = scale_chol @ A |
| 974 | + X = L_X @ L_X.mT |
931 | 975 |
|
932 | | -wishart = WishartRV() |
| 976 | + return cls( |
| 977 | + inputs=[rng, size, nu, scale_chol], |
| 978 | + outputs=[next_rng, X], |
| 979 | + )(rng, size, nu, scale_chol) |
933 | 980 |
|
934 | 981 |
|
935 | 982 | class Wishart(Continuous): |
936 | 983 | r""" |
937 | 984 | Wishart distribution. |
938 | 985 |
|
939 | | - The Wishart distribution is the probability distribution of the |
940 | | - maximum-likelihood estimator (MLE) of the precision matrix of a |
941 | | - multivariate normal distribution. If V=1, the distribution is |
942 | | - identical to the chi-square distribution with nu degrees of |
943 | | - freedom. |
| 986 | + The Wishart distribution is the probability distribution over symmetric |
| 987 | + positive-definite matrices that arises as the distribution of the sum of |
| 988 | + outer products of i.i.d. multivariate normal vectors. If |
| 989 | + :math:`x_i \sim \mathcal{N}(0, V)` are i.i.d. for :math:`i = 1, \dots, \nu`, |
| 990 | + then :math:`X = \sum_i x_i x_i^\top \sim \mathrm{Wishart}_p(\nu, V)`. |
944 | 991 |
|
945 | 992 | .. math:: |
946 | 993 |
|
947 | | - f(X \mid nu, T) = |
948 | | - \frac{{\mid T \mid}^{nu/2}{\mid X \mid}^{(nu-k-1)/2}}{2^{nu k/2} |
949 | | - \Gamma_p(nu/2)} \exp\left\{ -\frac{1}{2} Tr(TX) \right\} |
| 994 | + f(X \mid \nu, V) = |
| 995 | + \frac{|X|^{(\nu-p-1)/2}}{2^{\nu p / 2} |V|^{\nu / 2} \Gamma_p(\nu / 2)} |
| 996 | + \exp\left\{ -\frac{1}{2} \operatorname{tr}(V^{-1} X) \right\} |
950 | 997 |
|
951 | | - where :math:`k` is the rank of :math:`X`. |
| 998 | + where :math:`p` is the dimension of the matrix :math:`X`. |
952 | 999 |
|
953 | 1000 | ======== ========================================= |
954 | | - Support :math:`X(p x p)` positive definite matrix |
955 | | - Mean :math:`nu V` |
956 | | - Variance :math:`nu (v_{ij}^2 + v_{ii} v_{jj})` |
| 1001 | + Support :math:`X\,(p \times p)` positive definite matrix |
| 1002 | + Mean :math:`\nu V` |
| 1003 | + Variance :math:`\nu (v_{ij}^2 + v_{ii} v_{jj})` |
957 | 1004 | ======== ========================================= |
958 | 1005 |
|
959 | 1006 | Parameters |
960 | 1007 | ---------- |
961 | 1008 | nu : tensor_like of int |
962 | | - Degrees of freedom, > 0. |
963 | | - V : tensor_like of float |
964 | | - p x p positive definite matrix. |
| 1009 | + Degrees of freedom, must satisfy ``nu > p - 1``. |
| 1010 | + V : tensor_like of float, optional |
| 1011 | + ``(p, p)`` symmetric positive-definite scale matrix. Mutually exclusive with |
| 1012 | + ``scale_chol``. |
| 1013 | + scale_chol : tensor_like of float, optional |
| 1014 | + ``(p, p)`` lower-triangular Cholesky factor of the scale matrix |
| 1015 | + (``V = scale_chol @ scale_chol.T``). Provide this when you already have the |
| 1016 | + decomposition to avoid a redundant Cholesky inside ``logp``. Mutually |
| 1017 | + exclusive with ``V``. |
965 | 1018 |
|
966 | 1019 | Notes |
967 | 1020 | ----- |
968 | | - This distribution is unusable in a PyMC model. You should instead |
969 | | - use LKJCholeskyCov or LKJCorr. |
| 1021 | + The default unconstraining transform is :class:`CholeskyCovTransform`, which |
| 1022 | + parameterizes ``X = L @ L.T`` from a free real vector with ``log L_kk`` on the |
| 1023 | + diagonal. |
970 | 1024 | """ |
971 | 1025 |
|
972 | | - rv_op = wishart |
| 1026 | + rv_type = WishartRV |
| 1027 | + rv_op = WishartRV.rv_op |
| 1028 | + |
| 1029 | + @classmethod |
| 1030 | + def _resolve_scale(cls, V, scale_chol): |
| 1031 | + if (V is None) == (scale_chol is None): |
| 1032 | + raise ValueError("Wishart requires exactly one of `V` or `scale_chol`.") |
| 1033 | + if scale_chol is not None: |
| 1034 | + return pt.as_tensor_variable(scale_chol) |
| 1035 | + return pt.linalg.cholesky(pt.as_tensor_variable(V)) |
973 | 1036 |
|
974 | 1037 | @classmethod |
975 | | - def dist(cls, nu, V, *args, **kwargs): |
| 1038 | + def dist(cls, nu, V=None, *args, scale_chol=None, **kwargs): |
976 | 1039 | nu = pt.as_tensor_variable(nu, dtype=int) |
977 | | - V = pt.as_tensor_variable(V) |
978 | | - |
979 | | - warnings.warn( |
980 | | - "The Wishart distribution can currently not be used " |
981 | | - "for MCMC sampling. The probability of sampling a " |
982 | | - "symmetric matrix is basically zero. Instead, please " |
983 | | - "use LKJCholeskyCov or LKJCorr. For more information " |
984 | | - "on the issues surrounding the Wishart see here: " |
985 | | - "https://github.com/pymc-devs/pymc/issues/538.", |
986 | | - UserWarning, |
987 | | - ) |
| 1040 | + scale_chol = cls._resolve_scale(V, scale_chol) |
| 1041 | + return super().dist([nu, scale_chol], *args, **kwargs) |
988 | 1042 |
|
989 | | - # mean = nu * V |
990 | | - # p = V.shape[0] |
991 | | - # mode = pt.switch(pt.ge(nu, p + 1), (nu - p - 1) * V, np.nan) |
992 | | - return super().dist([nu, V], *args, **kwargs) |
| 1043 | + def support_point(rv, size, nu, scale_chol): |
| 1044 | + # Mean of Wishart(nu, V) is nu * V = nu * (L_V @ L_V.T). Always in the |
| 1045 | + # SPD support for valid nu > p - 1, so it's a safe initial point. |
| 1046 | + V = scale_chol @ scale_chol.mT |
| 1047 | + return nu.astype(V.dtype) * V |
993 | 1048 |
|
994 | | - def logp(X, nu, V): |
| 1049 | + def logp(X, nu, scale_chol): |
995 | 1050 | """ |
996 | | - Calculate logp of Wishart distribution at specified value. |
| 1051 | + Log-density of the Wishart distribution at the SPD value ``X``. |
997 | 1052 |
|
998 | | - Parameters |
999 | | - ---------- |
1000 | | - X: numeric |
1001 | | - Value for which log-probability is calculated. |
1002 | | -
|
1003 | | - Returns |
1004 | | - ------- |
1005 | | - TensorVariable |
| 1053 | + Implemented in Cholesky form: when the value comes from the unconstraining |
| 1054 | + ``CholeskyCovTransform``, ``cholesky(L @ L.T)`` rewrites to ``L`` and no |
| 1055 | + decomposition runs at runtime. |
1006 | 1056 | """ |
1007 | | - p = V.shape[0] |
| 1057 | + p = X.shape[-1] |
1008 | 1058 |
|
1009 | | - IVI = det(V) |
1010 | | - IXI = det(X) |
| 1059 | + L_X = pt.linalg.cholesky(X) |
| 1060 | + log_det_X = 2 * pt.sum(pt.log(pt.diagonal(L_X, axis1=-2, axis2=-1)), axis=-1) |
| 1061 | + log_det_V = 2 * pt.sum(pt.log(pt.diagonal(scale_chol, axis1=-2, axis2=-1)), axis=-1) |
| 1062 | + # tr(V^{-1} X) = ||L_V^{-1} L_X||_F^2 via a triangular solve. |
| 1063 | + M = solve_triangular(scale_chol, L_X, lower=True) |
| 1064 | + tr_term = pt.sum(M**2, axis=(-2, -1)) |
1011 | 1065 |
|
1012 | 1066 | return check_parameters( |
1013 | 1067 | ( |
1014 | | - (nu - p - 1) * pt.log(IXI) |
1015 | | - - trace(matrix_inverse(V).dot(X)) |
| 1068 | + (nu - p - 1) * log_det_X |
| 1069 | + - tr_term |
1016 | 1070 | - nu * p * pt.log(2) |
1017 | | - - nu * pt.log(IVI) |
| 1071 | + - nu * log_det_V |
1018 | 1072 | - 2 * multigammaln(nu / 2.0, p) |
1019 | 1073 | ) |
1020 | 1074 | / 2, |
1021 | | - matrix_pos_def(X), |
1022 | | - pt.eq(X, X.T), |
1023 | 1075 | nu > (p - 1), |
| 1076 | + msg="nu > p - 1", |
1024 | 1077 | ) |
1025 | 1078 |
|
1026 | 1079 |
|
1027 | | -def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initval=None): |
1028 | | - r""" |
1029 | | - Bartlett decomposition of the Wishart distribution. |
1030 | | -
|
1031 | | - As the Wishart distribution requires the matrix to be symmetric positive |
1032 | | - semi-definite, it is impossible for MCMC to ever propose acceptable matrices. |
1033 | | -
|
1034 | | - Instead, we can use the Barlett decomposition which samples a lower |
1035 | | - diagonal matrix. Specifically: |
| 1080 | +@_default_transform.register(WishartRV) |
| 1081 | +def wishart_default_transform(op, rv): |
| 1082 | + _, _, _, scale_chol = rv.owner.inputs |
| 1083 | + n = scale_chol.shape[-1] |
| 1084 | + return CholeskyCovTransform(n=n) |
1036 | 1085 |
|
1037 | | - .. math:: |
1038 | | - \text{If} L \sim \begin{pmatrix} |
1039 | | - \sqrt{c_1} & 0 & 0 \\ |
1040 | | - z_{21} & \sqrt{c_2} & 0 \\ |
1041 | | - z_{31} & z_{32} & \sqrt{c_3} |
1042 | | - \end{pmatrix} |
1043 | 1086 |
|
1044 | | - \text{with} c_i \sim \chi^2(n-i+1) \text{ and } n_{ij} \sim \mathcal{N}(0, 1), \text{then} \\ |
1045 | | - L \times A \times A.T \times L.T \sim \text{Wishart}(L \times L.T, \nu) |
| 1087 | +def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initval=None): |
| 1088 | + r""" |
| 1089 | + Bartlett-decomposed Wishart prior. **Deprecated**: use :class:`Wishart` directly. |
1046 | 1090 |
|
1047 | | - See http://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition |
1048 | | - for more information. |
| 1091 | + This used to be the only MCMC-usable Wishart in PyMC, since the legacy |
| 1092 | + :class:`Wishart` had no unconstraining transform. The new :class:`Wishart` |
| 1093 | + is parameterized through its Cholesky factor with a default |
| 1094 | + :class:`CholeskyCovTransform`, so this helper is no longer needed and is |
| 1095 | + a thin shim around it for backward compatibility. |
1049 | 1096 |
|
1050 | 1097 | Parameters |
1051 | 1098 | ---------- |
1052 | 1099 | S : ndarray |
1053 | | - p x p positive definite matrix |
1054 | | - Or: |
1055 | | - p x p lower-triangular matrix that is the Cholesky factor |
1056 | | - of the covariance matrix. |
| 1100 | + ``(p, p)`` positive-definite scale matrix, or its lower-triangular |
| 1101 | + Cholesky factor when ``is_cholesky=True``. |
1057 | 1102 | nu : tensor_like of int |
1058 | | - Degrees of freedom, > dim(S). |
1059 | | - is_cholesky : bool, default=False |
1060 | | - Input matrix S is already Cholesky decomposed as S.T * S |
1061 | | - return_cholesky : bool, default=False |
1062 | | - Only return the Cholesky decomposed matrix. |
1063 | | - initval : ndarray |
1064 | | - p x p positive definite matrix used to initialize |
1065 | | -
|
1066 | | - Notes |
1067 | | - ----- |
1068 | | - This is not a standard Distribution class but follows a similar |
1069 | | - interface. Besides the Wishart distribution, it will add RVs |
1070 | | - name_c and name_z to your model which make up the matrix. |
1071 | | -
|
1072 | | - This distribution is usually a bad idea to use as a prior for multivariate |
1073 | | - normal. You should instead use LKJCholeskyCov or LKJCorr. |
| 1103 | + Degrees of freedom, > ``dim(S) - 1``. |
| 1104 | + is_cholesky : bool, default False |
| 1105 | + If True, ``S`` is interpreted as the Cholesky factor of the scale matrix |
| 1106 | + (mapped to :class:`Wishart`'s ``scale_chol`` argument). |
| 1107 | + return_cholesky : bool, default False |
| 1108 | + If True, return the Cholesky factor of the Wishart draw rather than the |
| 1109 | + SPD matrix itself. |
| 1110 | + initval : ndarray, optional |
| 1111 | + Forwarded to :class:`Wishart`. Note that ``initval`` semantics changed: |
| 1112 | + the new :class:`Wishart` expects an SPD matrix value, not the legacy |
| 1113 | + Bartlett component initvals. |
1074 | 1114 | """ |
1075 | | - L = S if is_cholesky else scipy.linalg.cholesky(S) |
1076 | | - diag_idx = np.diag_indices_from(S) |
1077 | | - tril_idx = np.tril_indices_from(S, k=-1) |
1078 | | - n_diag = len(diag_idx[0]) |
1079 | | - n_tril = len(tril_idx[0]) |
1080 | | - |
1081 | | - if initval is not None: |
1082 | | - # Inverse transform |
1083 | | - initval = np.dot(np.dot(np.linalg.inv(L), initval), np.linalg.inv(L.T)) |
1084 | | - initval = scipy.linalg.cholesky(initval, lower=True) |
1085 | | - diag_testval = initval[diag_idx] ** 2 |
1086 | | - tril_testval = initval[tril_idx] |
1087 | | - else: |
1088 | | - diag_testval = None |
1089 | | - tril_testval = None |
1090 | | - |
1091 | | - c = pt.sqrt( |
1092 | | - ChiSquared(f"{name}_c", nu - np.arange(2, 2 + n_diag), shape=n_diag, initval=diag_testval) |
| 1115 | + warnings.warn( |
| 1116 | + "WishartBartlett is deprecated and will be removed in a future release. " |
| 1117 | + "Use pm.Wishart directly. For `is_cholesky=True`, pass `scale_chol=S`. " |
| 1118 | + "For `return_cholesky=True`, wrap pm.Wishart in pt.linalg.cholesky as a " |
| 1119 | + "Deterministic.", |
| 1120 | + FutureWarning, |
| 1121 | + stacklevel=2, |
1093 | 1122 | ) |
1094 | | - pm._log.info(f"Added new variable {name}_c to model diagonal of Wishart.") |
1095 | | - z = Normal(f"{name}_z", 0.0, 1.0, shape=n_tril, initval=tril_testval) |
1096 | | - pm._log.info(f"Added new variable {name}_z to model off-diagonals of Wishart.") |
1097 | | - # Construct A matrix |
1098 | | - A = pt.zeros(S.shape, dtype=np.float32) |
1099 | | - A = pt.set_subtensor(A[diag_idx], c) |
1100 | | - A = pt.set_subtensor(A[tril_idx], z) |
1101 | | - |
1102 | | - # L * A * A.T * L.T ~ Wishart(L*L.T, nu) |
| 1123 | + scale_kwargs = {"scale_chol": S} if is_cholesky else {"V": S} |
1103 | 1124 | if return_cholesky: |
1104 | | - return pm.Deterministic(name, pt.dot(L, A)) |
1105 | | - else: |
1106 | | - return pm.Deterministic(name, pt.dot(pt.dot(pt.dot(L, A), A.T), L.T)) |
| 1125 | + wishart_rv = Wishart(f"_{name}_wishart", nu=nu, initval=initval, **scale_kwargs) |
| 1126 | + return pm.Deterministic(name, pt.linalg.cholesky(wishart_rv)) |
| 1127 | + return Wishart(name, nu=nu, initval=initval, **scale_kwargs) |
1107 | 1128 |
|
1108 | 1129 |
|
1109 | 1130 | def _lkj_normalizing_constant(eta, n): |
|
0 commit comments