
Unconstrain transform for Wishart #8246

Open

ricardoV94 wants to merge 2 commits into pymc-devs:v6 from ricardoV94:Wishart

Conversation

ricardoV94 (Member) commented Apr 9, 2026

Similar idea to #7380 (but this is actually simpler). Almost the same rewrite as LKJCholeskyCov, except here we unconstrain to the full dense matrix.

Closes #8196 (it's now usable)
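
For context, a minimal numpy sketch of the backward (unconstrained vector → dense matrix) map this implies, assuming the packing convention visible in the graphs below (row-major lower triangle with exp on the diagonal slots); the function name and layout are illustrative, not the PR's code:

import numpy as np

def wishart_backward(unc, n):
    rows, cols = np.tril_indices(n)                # e.g. [0 1 1 2 2 2], [0 0 1 0 1 2]
    diag_idxs = np.where(rows == cols)[0]          # e.g. [0 2 5]
    packed = unc.astype(float)                     # copy of the packed unconstrained vector
    packed[diag_idxs] = np.exp(packed[diag_idxs])  # positive diagonal for L
    L = np.zeros((n, n))
    L[rows, cols] = packed                         # scatter packed vector into lower-tri L
    return L @ L.T                                 # dense symmetric positive-definite value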

ricardoV94 changed the title Unconstraint Wishart → Unconstrain Wishart Apr 9, 2026

codecov Bot commented Apr 9, 2026

Codecov Report

❌ Patch coverage is 90.72165% with 9 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (v6@d77d6c3). Learn more about missing BASE report.

Files with missing lines             Patch %   Lines
pymc/distributions/multivariate.py   86.95%    9 Missing ⚠️
Additional details and impacted files


@@          Coverage Diff          @@
##             v6    #8246   +/-   ##
=====================================
  Coverage      ?   91.77%           
=====================================
  Files         ?      124           
  Lines         ?    20003           
  Branches      ?        0           
=====================================
  Hits          ?    18358           
  Misses        ?     1645           
  Partials      ?        0           
Files with missing lines             Coverage Δ
pymc/distributions/transforms.py     100.00% <100.00%> (ø)
pymc/testing.py                      90.87% <100.00%> (ø)
pymc/distributions/multivariate.py   95.37% <86.95%> (ø)


ricardoV94 (Member, Author) commented Apr 9, 2026

Took a look at the compiled logp+dlogp, and we pay some price for the whole matrix construction.

Edit: the graph is pretty clean now (see below); the original analysis is collapsed under Details.

Details

For an n × n Wishart the unconstrained vector has length n(n+1)/2, with diagonal positions at the cumulative-sum sequence [0, 2, 5, 9, …] of length n.
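
A quick check of that index arithmetic (illustrative, assuming the row-major np.tril_indices packing used by the graphs below):

import numpy as np

n = 5
diag_idxs = np.cumsum(np.arange(1, n + 1)) - 1   # [0 2 5 9 14]
# equivalently, the k-th diagonal entry sits at k*(k+3)//2 in the packed vector
assert np.array_equal(diag_idxs, np.array([k * (k + 3) // 2 for k in range(n)]))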

Full logp+dlogp graph

# logp
Composite{(2.079441547393799 + (0.5 * ((-14.159198660192542 + (2.0 * i2)) - i1)) + i0)} [id A] 17
 ├─ Sum{axes=None} [id B] 5
 │  └─ Mul [id C] 3
 │     ├─ [4. 3. 2.] [id D]
 │     └─ AdvancedSubtensor1 [id E] 0
 │        ├─ Sigma_cholesky-cov__ [id F]
 │        └─ [0 2 5] [id G]
 ├─ Sum{axes=None} [id H] 15
 │  └─ Sqr [id I] 13
 │     └─ SolveTriangular{unit_diagonal=False, lower=True, b_ndim=2, overwrite_b=True} [id J] 10
 │        ├─ [[1. 0. 0. ... 0. 0. 1.]] [id K]
 │        └─ AdvancedSetSubtensor [id L] 6
 │           ├─ Alloc [id M] 1
 │           │  ├─ 0.0 [id N]
 │           │  ├─ 3 [id O]
 │           │  └─ 3 [id O]
 │           ├─ AdvancedIncSubtensor1{no_inplace,set} [id P] 4
 │           │  ├─ Sigma_cholesky-cov__ [id F]
 │           │  ├─ Exp [id Q] 2
 │           │  │  └─ AdvancedSubtensor1 [id E] 0
 │           │  │     └─ ···
 │           │  └─ [0 2 5] [id G]
 │           ├─ [0 1 1 2 2 2] [id R]
 │           └─ [0 0 1 0 1 2] [id S]
 └─ Sum{axes=None} [id T] 12
    └─ Log [id U] 9
       └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=True} [id V] 7
          └─ AdvancedSetSubtensor [id L] 6
             └─ ···
# dlogp
AdvancedIncSubtensor1{inplace,inc} [id W] 'Sigma_cholesky-cov___grad' 23
 ├─ AdvancedIncSubtensor1{inplace,set} [id X] 21
 │  ├─ AdvancedSubtensor{idx_list=(0, 1)} [id Y] 19
 │  │  ├─ Add [id Z] 18
 │  │  │  ├─ AdvancedSetSubtensor [id BA] 11
 │  │  │  │  ├─ Alloc [id M] 1
 │  │  │  │  │  └─ ···
 │  │  │  │  ├─ Reciprocal [id BB] 8
 │  │  │  │  │  └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=True} [id V] 7
 │  │  │  │  │     └─ ···
 │  │  │  │  ├─ [0 1 2] [id BC]
 │  │  │  │  └─ [0 1 2] [id BC]
 │  │  │  └─ SolveTriangular{unit_diagonal=False, lower=False, b_ndim=2, overwrite_b=True} [id BD] 16
 │  │  │     ├─ [[1. 0. 0. ... 0. 0. 1.]] [id K]
 │  │  │     └─ Neg [id BE] 14
 │  │  │        └─ SolveTriangular{unit_diagonal=False, lower=True, b_ndim=2, overwrite_b=True} [id J] 10
 │  │  │           └─ ···
 │  │  ├─ [0 1 1 2 2 2] [id R]
 │  │  └─ [0 0 1 0 1 2] [id S]
 │  ├─ [0. 0. 0.] [id BF]
 │  └─ [0 2 5] [id G]
 ├─ Composite{((i1 * i2) + i0)} [id BG] 22
 │  ├─ [4. 3. 2.] [id D]
 │  ├─ AdvancedSubtensor1 [id BH] 20
 │  │  ├─ AdvancedSubtensor{idx_list=(0, 1)} [id Y] 19
 │  │  │  └─ ···
 │  │  └─ [0 2 5] [id G]
 │  └─ Exp [id Q] 2
 │     └─ ···
 └─ [0 2 5] [id G]

Inner graphs:

Composite{(2.079441547393799 + (0.5 * ((-14.159198660192542 + (2.0 * i2)) - i1)) + i0)} [id A]
 ← add [id BI]
    ├─ 2.079441547393799 [id BJ]
    ├─ mul [id BK]
    │  ├─ 0.5 [id BL]
    │  └─ sub [id BM]
    │     ├─ add [id BN]
    │     │  ├─ -14.159198660192542 [id BO]
    │     │  └─ mul [id BP]
    │     │     ├─ 2.0 [id BQ]
    │     │     └─ i2 [id BR]
    │     └─ i1 [id BS]
    └─ i0 [id BT]

Composite{((i1 * i2) + i0)} [id BG]
 ← add [id BU]
    ├─ mul [id BV]
    │  ├─ i1 [id BS]
    │  └─ i2 [id BR]
    └─ i0 [id BT]

1. Diagonal gradient routed through an (n, n) scatter

AdvancedSubtensor{idx_list=(0, 1)} [id Y]        ← read length-n(n+1)/2 packed lower-tri
 ├─ Add [id Z]                                    ← add (n,n) matrices
 │  ├─ AdvancedSetSubtensor [id BA]              ← scatter 1/L_kk onto diag of (n,n) zeros
 │  │  ├─ Alloc [id M]                            ← (n,n) zeros
 │  │  ├─ Reciprocal [id BB]                      ← 1/L_kk, length n
 │  │  ├─ [0 1 2]
 │  │  └─ [0 1 2]
 │  └─ SolveTriangular{lower=False} [id BD]       ← full (n,n) −V⁻¹·L gradient term
 ├─ [0 1 1 2 2 2]
 └─ [0 0 1 0 1 2]

The two gradient contributions (1/L_kk on the diagonal, −V⁻¹ L everywhere) are added as full (n, n) matrices, then only the n(n+1)/2 lower-triangular entries are read out. The upper triangle is dead work. The diagonal scatter writes n values into zeros solely so they align with the (n, n) layout of the triangular-solve result.
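
To make the dead work concrete, here is a small numpy check (placeholder values, not the actual graph tensors): routing the diagonal term through a full (n, n) scatter and then gathering the packed lower triangle gives the same length-n(n+1)/2 gradient as adding the diagonal term directly at the packed diagonal slots.

import numpy as np

n = 3
rng = np.random.default_rng(0)
L_diag = np.exp(rng.normal(size=n))        # stands in for L_kk
G = rng.normal(size=(n, n))                # stands in for the full (n, n) -V^-1 L term
rows, cols = np.tril_indices(n)            # [0 1 1 2 2 2], [0 0 1 0 1 2]
diag_idxs = np.where(rows == cols)[0]      # [0 2 5]

# Current graph: scatter 1/L_kk onto (n, n) zeros, add full matrices, gather lower tri.
full = np.zeros((n, n))
full[np.arange(n), np.arange(n)] = 1.0 / L_diag
packed_a = (full + G)[rows, cols]

# Packed-only form: gather first, then add the diagonal term in place.
packed_b = G[rows, cols].copy()
packed_b[diag_idxs] += 1.0 / L_diag

assert np.allclose(packed_a, packed_b)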

2. Extracting the diagonal we just placed

ExtractDiag [id V]
 └─ AdvancedSetSubtensor [id L]      ← scatter packed vec into (n,n) zeros
    ├─ Alloc [id M]                   ← (n,n) zeros
    ├─ AdvancedIncSubtensor1 [id P]  ← packed vec with Exp on diag slots
    │  ├─ Sigma_cholesky-cov__ [id F]
    │  ├─ Exp [id Q]                  ← exp(unc[diag_idxs]) = diag(L)
    │  └─ [0 2 5]
    ├─ [0 1 1 2 2 2]
    └─ [0 0 1 0 1 2]

ExtractDiag(L) recovers exactly Exp [id Q], the values we scattered onto the diagonal in the first place. Recognizing this identity simplifies both consumers:

  • logp: Sum(Log(ExtractDiag(L))) = Sum(Log(Exp(unc[diag_idxs]))) = Sum(unc[diag_idxs]).
    The Log, ExtractDiag, and Exp all cancel.

  • dlogp: Reciprocal(ExtractDiag(L)) = 1/Exp(unc[diag_idxs]).
    This is multiplied by L_kk = Exp(unc[diag_idxs]) via the chain rule in the per-diagonal Composite, giving (1/L_kk) · L_kk = 1. The diagonal gradient from the log-det term becomes a constant +1 per diagonal slot, absorbable into the existing log-Jacobian coefficients, turning [n+1, n, …, 2] into [n+2, n+1, …, 3] (see the numerical check after this list).
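
A minimal numpy check of both identities (illustrative values only):

import numpy as np

unc_diag = np.random.default_rng(2).normal(size=3)  # stands in for unc[diag_idxs]
L_kk = np.exp(unc_diag)                              # diag(L) as built by the transform

# logp: Log(ExtractDiag(L)) = Log(Exp(unc[diag_idxs])) = unc[diag_idxs]
assert np.allclose(np.log(L_kk), unc_diag)

# dlogp: the chain-rule factor (1/L_kk) * L_kk is a constant 1 per diagonal slot
assert np.allclose((1.0 / L_kk) * L_kk, np.ones(3))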

3. Set-then-inc on the same diagonal positions

AdvancedIncSubtensor1{inc} [id W]         ← inc at [0 2 5]
 ├─ AdvancedIncSubtensor1{set} [id X]     ← set [0 2 5] to zero
 │  ├─ [id Y]                              ← length-n(n+1)/2 packed gradient (from §1)
 │  ├─ [0. 0. 0.]
 │  └─ [0 2 5]
 ├─ Composite{(i1 * i2) + i0} [id BG]    ← length-n diagonal contribution
 │  ├─ [4. 3. 2.]                          ← log-Jacobian coefficients
 │  ├─ AdvancedSubtensor1 [id BH]         ← diag slice of [id Y] (read before zeroing)
 │  └─ Exp [id Q]                          ← L_kk
 └─ [0 2 5]

The set zeros the diagonal slots; the inc overwrites them with (Y[diag_idxs] · L_kk) + [4, 3, 2]. The zeroing is an autodiff artifact. With §1–§2 applied, Y[diag_idxs] simplifies (the Reciprocal scatter becomes a constant), and the entire set-then-inc collapses to a single inc_subtensor of one fused length-n vector, as sketched below.
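
A small numpy check (with placeholder values, not the graph's actual tensors) that setting the diagonal slots to zero and then incrementing them is just one scatter of the fused vector:

import numpy as np

rng = np.random.default_rng(3)
packed = rng.normal(size=6)            # stands in for the length-n(n+1)/2 packed gradient Y
diag_idxs = np.array([0, 2, 5])
fused = rng.normal(size=3)             # stands in for (Y[diag_idxs] * L_kk) + coefficients

two_step = packed.copy()
two_step[diag_idxs] = 0.0              # AdvancedIncSubtensor1{set}: zero the diagonal slots
two_step[diag_idxs] += fused           # AdvancedIncSubtensor1{inc}: add the fused vector

one_step = packed.copy()
one_step[diag_idxs] = fused            # single scatter of the fused length-n vector

assert np.allclose(two_step, one_step)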

4. Structural lower bound after all three simplifications

AdvancedIncSubtensor1{inc}                ← single inc at packed diag positions
 ├─ AdvancedSubtensor{idx_list=(0, 1)}    ← length-n(n+1)/2 packed lower-tri of −V⁻¹·L
 │  ├─ SolveTriangular{lower=False}       ← (n,n), unchanged
 │  ├─ [0 1 1 2 2 2]
 │  └─ [0 0 1 0 1 2]
 ├─ Composite                             ← length-n fused diagonal contribution
 │  ├─ [n+2, n+1, …, 3]                   ← merged log-Jacobian + log-det constant
 │  └─ Exp(unc[diag_idxs])                ← shared with L construction
 └─ [0 2 5]

What's already good

  • L built once, used three times: forward SolveTriangular, ExtractDiag, and gradient's second triangular solve.
  • Single Alloc: the (n, n) zero buffer is shared between L construction and the §1 diagonal-scatter.
  • Forward solve shared: M = L⁻¹ V serves both ‖M‖²_F (trace term) and −Lᵀ \ M (gradient term).
  • Unconstrained diagonal shared: unc[diag_idxs] and its Exp are CSE'd between L construction and the gradient's chain-rule factor.
  • No Cholesky op: the cholesky_ldotlt rewrite has already eliminated the chol(L Lᵀ) round trip.


ricardoV94 (Member, Author) commented Apr 9, 2026

With a few general rewrites (already upstreamed in pymc-devs/pytensor#2061), I got it down to this form:

import pymc as pm
import numpy as np

n = 3  # matrix dimension; the printed graph below is for a 3x3 Wishart
rng = np.random.default_rng(1)
A = rng.normal(size=(n, n))
V = A @ A.T + n * np.eye(n)  # well-conditioned positive-definite scale matrix

with pm.Model() as m:
    pm.Wishart("Sigma", nu=4, V=V)

m.logp_dlogp_function()._pytensor_function.dprint(print_shape=True, print_memory_map=True)
Composite{(2.079441547393799 + (0.5 * (-28.789189739493835 - i1)) + i0)} [id A] shape=() d={0: [0]} 12
 ├─ Sum{axes=None} [id B] shape=() 5
 │  └─ Mul [id C] shape=(?,) 3
 │     ├─ [4. 3. 2.] [id D] shape=(3,)
 │     └─ AdvancedSubtensor1 [id E] shape=(3,) 0
 │        ├─ Sigma_cholesky-cov__ [id F] shape=(?,)
 │        └─ [0 2 5] [id G] shape=(3,)
 └─ Sum{axes=None} [id H] shape=() 10
    └─ Sqr [id I] shape=(3, 3) 8
       └─ SolveTriangular{unit_diagonal=False, lower=True, b_ndim=2, overwrite_b=True} [id J] shape=(3, 3) d={0: [1]} 7
          ├─ [[1.975771 ... 84325469]] [id K] shape=(3, 3)
          └─ AdvancedSetSubtensor [id L] shape=(3, 3) d={0: [0]} 6
             ├─ Alloc [id M] shape=(3, 3) 1
             │  ├─ 0.0 [id N] shape=()
             │  ├─ 3 [id O] shape=()
             │  └─ 3 [id O] shape=()
             ├─ AdvancedIncSubtensor1{no_inplace,set} [id P] shape=(?,) 4
             │  ├─ Sigma_cholesky-cov__ [id F] shape=(?,)
             │  ├─ Exp [id Q] shape=(?,) 2
             │  │  └─ AdvancedSubtensor1 [id E] shape=(3,) 0
             │  │     └─ ···
             │  └─ [0 2 5] [id G] shape=(3,)
             ├─ [0 1 1 2 2 2] [id R] shape=(6,)
             └─ [0 0 1 0 1 2] [id S] shape=(6,)
AdvancedIncSubtensor1{inplace,set} [id T] shape=(6,) 'Sigma_cholesky-cov___grad' d={0: [0]} 16
 ├─ AdvancedSubtensor{idx_list=(0, 1)} [id U] shape=(6,) 13
 │  ├─ SolveTriangular{unit_diagonal=False, lower=False, b_ndim=2, overwrite_b=True} [id V] shape=(3, 3) d={0: [1]} 11
 │  │  ├─ [[1.975771 ... 84325469]] [id W] shape=(3, 3)
 │  │  └─ Neg [id X] shape=(3, 3) d={0: [0]} 9
 │  │     └─ SolveTriangular{unit_diagonal=False, lower=True, b_ndim=2, overwrite_b=True} [id J] shape=(3, 3) d={0: [1]} 7
 │  │        └─ ···
 │  ├─ [0 1 1 2 2 2] [id R] shape=(6,)
 │  └─ [0 0 1 0 1 2] [id S] shape=(6,)
 ├─ Composite{((i1 * i2) + i0)} [id Y] shape=(3,) d={0: [1]} 15
 │  ├─ [4. 3. 2.] [id D] shape=(3,)
 │  ├─ AdvancedSubtensor1 [id Z] shape=(3,) 14
 │  │  ├─ AdvancedSubtensor{idx_list=(0, 1)} [id U] shape=(6,) 13
 │  │  │  └─ ···
 │  │  └─ [0 2 5] [id G] shape=(3,)
 │  └─ Exp [id Q] shape=(?,) 2
 │     └─ ···
 └─ [0 2 5] [id G] shape=(3,)

Inner graphs:

Composite{(2.079441547393799 + (0.5 * (-28.789189739493835 - i1)) + i0)} [id A] d={0: [0]}
 ← add [id BA] shape=()
    ├─ 2.079441547393799 [id BB] shape=()
    ├─ mul [id BC] shape=()
    │  ├─ 0.5 [id BD] shape=()
    │  └─ sub [id BE] shape=()
    │     ├─ -28.789189739493835 [id BF] shape=()
    │     └─ i1 [id BG] shape=()
    └─ i0 [id BH] shape=()

Composite{((i1 * i2) + i0)} [id Y] d={0: [1]}
 ← add [id BI] shape=()
    ├─ mul [id BJ] shape=()
    │  ├─ i1 [id BG] shape=()
    │  └─ i2 [id BK] shape=()
    └─ i0 [id BH] shape=()

So, about as good as I can think of.

ricardoV94 changed the title Unconstrain Wishart → Unconstrain transform Wishart Apr 19, 2026
ricardoV94 marked this pull request as ready for review April 23, 2026 17:52
ricardoV94 changed the title Unconstrain transform Wishart → Unconstrain transform for Wishart Apr 23, 2026