
Commit 1b97eb0

mtfishman and claude committed
Refactor Val{2} √S split via sqrt(S, co, dom) + replacedimnames
Drop the local `balanced_eigh_factorization` stand-in in favor of NamedDimsArrays' existing `Base.sqrt(::NDA, codomain, domain)`, which returns the matrix square root as a single named array, and split the result into two factors at the call site via `replacedimnames`.

The "transposition-via-relabel" on `cache![v1 => v2]` (swap the codomain/domain name slots, then rename the displaced leg to a fresh name) ensures each directed sqrt-message has the correct arrow direction on its matching leg. For dense backings `sqrt_S` equals its transpose, so the swap is numerically a no-op, but the distinction matters for graded / fermionic axes.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
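The dense-backing claim above can be checked with plain matrices. The following is a minimal sketch using only Julia's `LinearAlgebra` standard library; it does not use NamedDimsArrays, and the concrete values of `U_v1`, `S`, and `U_v2` are made up for illustration. It shows that absorbing one `sqrt_S` on each side reproduces `U_v1 * S * U_v2`, and that a diagonal `sqrt_S` equals its transpose.

```julia
using LinearAlgebra

# Hypothetical stand-ins for the SVD factors (plain matrices, no dimnames).
U_v1 = [1.0 2.0; 3.0 4.0; 5.0 6.0]     # 3×2
S    = Diagonal([0.8, 0.6])            # diagonal PSD, already normalized
U_v2 = [1.0 0.0 -1.0; 0.5 2.0 1.0]     # 2×3

# Matrix square root of S; for a Diagonal this is the elementwise sqrt.
sqrt_S = sqrt(S)

# Absorb one sqrt factor on each side of the bond.
R_v1 = U_v1 * sqrt_S
R_v2 = sqrt_S * U_v2

@assert sqrt_S * sqrt_S ≈ S                 # sqrt_S is a square root of S
@assert transpose(sqrt_S) == sqrt_S         # dense diagonal case: the swap is a no-op
@assert R_v1 * R_v2 ≈ U_v1 * S * U_v2       # the split preserves the product
```

In the graded / fermionic setting the transpose also flips arrow directions, which plain matrices cannot represent, so this sketch only covers the dense case.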
1 parent 36e957c commit 1b97eb0

2 files changed

Lines changed: 47 additions & 67 deletions

src/apply/apply_operators.jl

Lines changed: 33 additions & 13 deletions
@@ -268,22 +268,42 @@ function apply_gate_bp_nsite!(
         S = S / norm(S)
     end
     name_v1, name_v2 = dimnames(S)
-    sqrt_S_v1, sqrt_S_v2 = balanced_eigh_factorization(S, (name_v1,), (name_v2,))
-    R_v1 = U_v1 * sqrt_S_v1
-    R_v2 = sqrt_S_v2 * U_v2
+    # `sqrt(S, (name_v1,), (name_v2,))` is NDA's matrix sqrt of `S` —
+    # a single 2-leg named array with dimnames `(name_v1, name_v2)`
+    # satisfying `sqrt_S * sqrt_S ≈ S` in the matrix algebra (each
+    # `sqrt_S` factor contracts on one of `S`'s legs). Eventual endpoint:
+    # 1-arg `sqrt(S)` once `TA.svd` returns `S` as a `NamedDimsOperator`.
+    sqrt_S = sqrt(S, (name_v1,), (name_v2,))
+    # Build R factors by absorbing `sqrt_S` on each side; the rebind on
+    # the v1 side picks `name_v1` as the new shared bond between
+    # `dest[v1]` and `dest[v2]`. With a `NamedDimsOperator` wrapper, the
+    # rebind becomes `apply(sqrt_S, U_v1)`.
+    R_v1 = replacedimnames(U_v1 * sqrt_S, name_v2 => name_v1)
+    R_v2 = sqrt_S * U_v2

     dest[v1] = prod([[Q_v1 * R_v1]; inv_sqrt_envs_v1])
     dest[v2] = prod([[Q_v2 * R_v2]; inv_sqrt_envs_v2])

-    # Reuse the two `sqrt_S` factors as new sqrt-messages, rebinding the
-    # outer (SVD-codomain / SVD-domain) leg to a fresh name per directed
-    # edge so the two messages don't share a leg name. Each direction
-    # picks the factor whose shared-bond arrow contracts with the
-    # receiving tensor: `sqrt_S_v1`'s bond arrow contracts with `dest[v2]`
-    # (v1 => v2), `sqrt_S_v2`'s with `dest[v1]` (v2 => v1). For dense
-    # backings the two factors carry the same data and the choice is
-    # invisible; the distinction matters for graded / fermionic axes.
-    cache![v1 => v2] = replacedimnames(sqrt_S_v1, name_v1 => randname(name_v1))
-    cache![v2 => v1] = replacedimnames(sqrt_S_v2, name_v2 => randname(name_v2))
+    # Both directed sqrt-messages derive from the same `sqrt_S`, but
+    # with different name-slot choices so each message's "matching" leg
+    # (name_v1, contracting with the receiving tensor) carries the
+    # correct arrow direction.
+    #
+    # `dest[v1]`'s name_v1 bond inherits the domain-side arrow of `S`
+    # (from the `name_v2 => name_v1` rebind in `R_v1`), and `dest[v2]`'s
+    # name_v1 bond inherits the codomain-side arrow (from `sqrt_S * U_v2`).
+    # So:
+    #   * `cache![v2 => v1]`'s matching leg needs the codomain-side arrow
+    #     → use sqrt_S's name_v1 leg directly; relabel name_v2 to fresh.
+    #   * `cache![v1 => v2]`'s matching leg needs the domain-side arrow
+    #     → swap roles: rename sqrt_S's name_v2 to name_v1, and the
+    #     original name_v1 (now the internal-rank slot) to a fresh name.
+    # For dense backings sqrt_S equals its transpose, so the two choices
+    # coincide numerically; the distinction matters for graded /
+    # fermionic axes.
+    cache![v1 => v2] = replacedimnames(
+        sqrt_S, name_v1 => randname(name_v1), name_v2 => name_v1
+    )
+    cache![v2 => v1] = replacedimnames(sqrt_S, name_v2 => randname(name_v2))
     return dest
 end
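
The "swap roles" relabel used for `cache![v1 => v2]` can be illustrated with a toy named matrix. This is a sketch under stated assumptions, not NamedDimsArrays code: `ToyNamed`, `relabel`, and `aligned` are hypothetical helpers, and `relabel` applies all `old => new` pairs simultaneously, which is the behavior the call site appears to rely on from `replacedimnames`. A deliberately non-symmetric matrix stands in for `sqrt_S` so that the effect of the swap is visible.

```julia
# Toy named matrix: `data[i, j]` with leg names `names = (row_name, col_name)`.
struct ToyNamed
    data::Matrix{Float64}
    names::Tuple{Symbol,Symbol}
end

# Hypothetical analogue of `replacedimnames`: all pairs are applied at once,
# so `:v1 => :fresh, :v2 => :v1` swaps slots instead of chaining renames.
function relabel(a::ToyNamed, pairs::Pair{Symbol,Symbol}...)
    d = Dict(pairs...)
    return ToyNamed(a.data, map(n -> get(d, n, n), a.names))
end

# Return the underlying matrix with legs ordered as (row, col).
function aligned(a::ToyNamed, row::Symbol, col::Symbol)
    a.names == (row, col) && return a.data
    a.names == (col, row) && return permutedims(a.data)
    error("leg names $(a.names) do not match ($row, $col)")
end

# Non-symmetric stand-in for sqrt_S, with legs (:v1, :v2).
sqrt_S = ToyNamed([1.0 2.0; 0.0 3.0], (:v1, :v2))

# v2 => v1 message: keep the :v1 leg, give the other leg a fresh name.
msg_v2_v1 = relabel(sqrt_S, :v2 => :fresh_a)
# v1 => v2 message: the old :v2 slot becomes :v1, the old :v1 slot gets a fresh name.
msg_v1_v2 = relabel(sqrt_S, :v1 => :fresh_b, :v2 => :v1)

# Read with the matching leg :v1 first, the second message is the transpose of the first.
@assert aligned(msg_v2_v1, :v1, :fresh_a) == sqrt_S.data
@assert aligned(msg_v1_v2, :v1, :fresh_b) == permutedims(sqrt_S.data)
```

For a symmetric `sqrt_S` (the dense diagonal case) the two aligned matrices coincide, which matches the comment that the choice is numerically invisible there.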

src/apply/tensoralgebra.jl

Lines changed: 14 additions & 54 deletions
@@ -89,57 +89,17 @@ function identity_map(::Type{T}, codomain_axes, domain_axes) where {T}
     return reshape(Matrix{T}(I, n_co, n_dom), (co_lens..., dom_lens...))
 end

-# === balanced_eigh_factorization ===
-#
-# Balanced eigh-based factorization of a Hermitian PSD named array `a`:
-# returns `(X, Y)` with `X * Y ≈ a` via named contraction, sharing a
-# fresh-named bond. For k-codomain input, `X` has names
-# `(codomain..., new_bond)` and `Y` has names `(new_bond, domain...)`.
-#
-# Conceptually: `a = U Λ U†` via eigh, then split Λ = √Λ · √Λ symmetrically
-# between the two halves so `X = U √Λ` and `Y = √Λ U†`. For
-# diagonal-Hermitian-PSD input (the BP simple-update SVD-`S` case),
-# eigh is trivial and this reduces to the per-element √ split.
-#
-# Layered through `TA.matricize` → matrix `sqrt` → `TA.unmatricize`,
-# matching the shape of `inv_regularized` above. The N-d / TA layer
-# is namespaced locally (intended `TA.balanced_eigh_factorization`),
-# the named layer extends here. See `gate_application/Overview.md` in
-# `ITensorDevelopmentPlans` for the operator-design synthesis this
-# slots into (`balanced_eigh_factor` single-factor companion,
-# `cholesky_factor`, `positive_factor` umbrella).
-
-function balanced_eigh_factorization(
-        style::TA.FusionStyle, A::AbstractArray, ndims_codomain::Val
-    )
-    M = TA.matricize(style, A, ndims_codomain)
-    sqrtM = sqrt(M)
-    biperm = TA.trivialbiperm(ndims_codomain, Val(ndims(A)))
-    axes_codomain, axes_domain = TA.blocks(axes(A)[biperm])
-    bond_axis = axes(sqrtM, 2)
-    return (
-        TA.unmatricize(style, sqrtM, axes_codomain, (bond_axis,)),
-        TA.unmatricize(style, sqrtM, (bond_axis,), axes_domain),
-    )
-end
-function balanced_eigh_factorization(A::AbstractArray, ndims_codomain::Val)
-    return balanced_eigh_factorization(TA.FusionStyle(A), A, ndims_codomain)
-end
-
-function balanced_eigh_factorization(
-        a::AbstractNamedDimsArray, codomain_dimnames, domain_dimnames
-    )
-    codomain_names = collect(name.(codomain_dimnames))
-    domain_names = collect(name.(domain_dimnames))
-    biperm = TA.blockedperm_indexin(
-        Tuple.((dimnames(a), codomain_names, domain_names))...
-    )
-    perm_codomain, perm_domain = TA.blocks(biperm)
-    A_perm = TA.bipermutedims(denamed(a), perm_codomain, perm_domain)
-    X_denamed, Y_denamed = balanced_eigh_factorization(A_perm, Val(length(perm_codomain)))
-    new_bond = randname(first(codomain_names))
-    return (
-        nameddims(X_denamed, [codomain_names; [new_bond]]),
-        nameddims(Y_denamed, [[new_bond]; domain_names]),
-    )
-end
+# Note: the BP simple-update `√S` split uses NDA's existing
+# `Base.sqrt(::AbstractNamedDimsArray, codomain_dimnames,
+# domain_dimnames)` (matrix sqrt as a single named array) directly,
+# combined with explicit `replacedimnames` at the call site to split
+# the result into two factors sharing a fresh bond. See the comment in
+# `apply_gate_bp_nsite!` (Val{2} method) for the call-site
+# choreography. A tuple-returning `factorize_sqrt` primitive — splitting
+# a Hermitian PSD `M` into `(X, Y)` with a fresh shared bond — was
+# previously staged here as a local stand-in but isn't needed for the
+# current `√S` use case (K=1 codomain). It can be reintroduced when a
+# multi-codomain (K>1) factorization use case lands, alongside the
+# rest of the `factorize_<backend>` family
+# (`factorize_balanced_eigh`, `factorize_cholesky`) discussed in
+# `gate_application/Overview.md` in `ITensorDevelopmentPlans`.
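
The balanced split described in the removed comment (`a = U Λ U†`, `X = U √Λ`, `Y = √Λ U†`) can be sketched on a plain matrix with Julia's `LinearAlgebra` standard library. This is an illustration only: the matrix `A` is made up, no TensorAlgebra or NamedDimsArrays calls are used, and it is not the removed implementation. The removed function placed the matrix square root itself in both slots, which is compared at the end.

```julia
using LinearAlgebra

# Hypothetical Hermitian PSD input (eigenvalues 1 and 3).
A = [2.0 1.0; 1.0 2.0]

# Eigendecomposition A = U Λ U†.
F = eigen(Hermitian(A))
U = F.vectors
sqrtΛ = Diagonal(sqrt.(F.values))

# Balanced split: half of Λ goes to each factor.
X = U * sqrtΛ
Y = sqrtΛ * U'
@assert X * Y ≈ A

# Single-array alternative: the matrix square root U √Λ U†,
# used on both sides of the bond.
sqrtA = sqrt(Hermitian(A))
@assert sqrtA * sqrtA ≈ A
@assert sqrtA ≈ U * sqrtΛ * U'
```

Both factorizations reproduce `A`; they differ by a unitary on the shared bond. For a diagonal PSD input such as the SVD `S`, `U` can be taken as the identity and both reduce to the elementwise square root.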
