Skip to content

Commit b6f824a

Browse files
mtfishmanclaude
andcommitted
Refactor initialize_cache to one(similar_operator(...)) form
Replace the dense `identity_map` helper with two composable primitives: * `similar_operator(prototype, codomain_axes)` — undef `NamedDimsOperator` with codomain = input axes, domain = same axes fresh-renamed. Backend / eltype propagates from `prototype` via `Base.similar`. * `Base.one(::AbstractNamedDimsOperator)` — identity operator via matricize → fill with `I` → unmatricize → rewrap. `initialize_cache` reduces to `state(one(similar_operator(factor, linkaxes(iterate, edge))))` per edge. Whitelist `Base.one` in `test_aqua.jl` as a stand-in extension that will move upstream into NDA's `MATRIX_FUNCTIONS` operator-extensions loop. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent 1b97eb0 commit b6f824a

3 files changed

Lines changed: 37 additions & 49 deletions

File tree

src/apply/apply_operators.jl

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ import TensorAlgebra as TA
55
using Base: @kwdef
66
using Graphs: dst, src, vertices
77
using LinearAlgebra: norm
8-
using NamedDimsArrays:
9-
AbstractNamedDimsArray, dimnames, domainnames, nameddims, randname, replacedimnames
8+
using NamedDimsArrays: AbstractNamedDimsArray, dimnames, domainnames, nameddims, randname,
9+
replacedimnames, state
1010
using NamedGraphs.GraphsExtensions: all_edges, boundary_edges
1111

1212
# === NestedAlgorithm framework ===
@@ -159,20 +159,13 @@ function AI.initialize_state!(
159159
return state
160160
end
161161

162-
# Identity-message cache: trivial Vidal-gauge initialization where each bond
163-
# carries the identity 2-leg map (= √I = I, in sqrt-message form). Stored
164-
# in a `SqrtMessageCache` so the BP simple update knows to use the messages
165-
# as gauge-in factors directly and skip the √ step.
162+
# Initialize the BP message cache to identity square-root messages.
166163
function initialize_cache(
167-
problem::ApplyOperatorProblem, ::BPApplyGate, iterate::AbstractTensorNetwork
164+
::ApplyOperatorProblem, ::BPApplyGate, iterate::AbstractTensorNetwork
168165
)
169-
T = eltype(iterate[first(vertices(iterate))])
170166
return sqrtmessagecache(all_edges(iterate)) do edge
171-
bond_name = only(linknames(iterate, edge))
172-
bond_axis = only(linkaxes(iterate, edge))
173-
fresh_name = randname(bond_name)
174-
A = identity_map(T, (bond_axis,), (bond_axis,))
175-
return nameddims(A, (fresh_name, bond_name))
167+
factor = iterate[dst(edge)]
168+
return state(one(similar_operator(factor, linkaxes(iterate, edge))))
176169
end
177170
end
178171

src/apply/tensoralgebra.jl

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
import MatrixAlgebraKit as MAK
2323
import TensorAlgebra as TA
2424
using LinearAlgebra: I
25-
using NamedDimsArrays: AbstractNamedDimsArray, denamed, dimnames, name, nameddims, randname
25+
using NamedDimsArrays: AbstractNamedDimsArray, AbstractNamedDimsOperator, codomainnames,
26+
denamed, dimnames, domainnames, name, nameddims, operator, randname, setname, state
2627

2728
# === N-d / TensorAlgebra layer ===
2829

@@ -69,37 +70,26 @@ function MAK.inv_regularized(
6970
return MAK.inv_regularized(a, codomain_names, domain_names; kwargs...)
7071
end
7172

72-
# === identity_map ===
73-
#
74-
# 2k-leg identity *map* (pairwise δ per (co_i, dom_i)):
75-
# `I_{co_1, dom_1} ⊗ … ⊗ I_{co_k, dom_k}` reshaped to a 2k-leg tensor.
76-
#
77-
# Local stand-in: dense-only. Eventual home is `TensorAlgebra.jl` with
78-
# an `AbstractNamedDimsArray` overload and axis-type dispatch for the
79-
# graded / FusionTensor specializations (see
80-
# `gate_application/Overview.md` in `ITensorDevelopmentPlans`).
81-
82-
function identity_map(::Type{T}, codomain_axes, domain_axes) where {T}
73+
function similar_operator(prototype::AbstractNamedDimsArray, codomain_axes)
8374
co_axes = Tuple(codomain_axes)
84-
dom_axes = Tuple(domain_axes)
85-
co_lens = length.(co_axes)
86-
dom_lens = length.(dom_axes)
87-
n_co = prod(co_lens; init = 1)
88-
n_dom = prod(dom_lens; init = 1)
89-
return reshape(Matrix{T}(I, n_co, n_dom), (co_lens..., dom_lens...))
75+
dom_axes = setname.(co_axes, randname.(name.(co_axes)))
76+
A = similar(denamed(prototype), (co_axes..., dom_axes...))
77+
return operator(A, collect(name.(co_axes)), collect(name.(dom_axes)))
9078
end
9179

92-
# Note: the BP simple-update `√S` split uses NDA's existing
93-
# `Base.sqrt(::AbstractNamedDimsArray, codomain_dimnames,
94-
# domain_dimnames)` (matrix sqrt as a single named array) directly,
95-
# combined with explicit `replacedimnames` at the call site to split
96-
# the result into two factors sharing a fresh bond. See the comment in
97-
# `apply_gate_bp_nsite!` (Val{2} method) for the call-site
98-
# choreography. A tuple-returning `factorize_sqrt` primitive — splitting
99-
# a Hermitian PSD `M` into `(X, Y)` with a fresh shared bond — was
100-
# previously staged here as a local stand-in but isn't needed for the
101-
# current `√S` use case (K=1 codomain). It can be reintroduced when a
102-
# multi-codomain (K>1) factorization use case lands, alongside the
103-
# rest of the `factorize_<backend>` family
104-
# (`factorize_balanced_eigh`, `factorize_cholesky`) discussed in
105-
# `gate_application/Overview.md` in `ITensorDevelopmentPlans`.
80+
function Base.one(a::AbstractNamedDimsOperator)
81+
co = codomainnames(a)
82+
dom = domainnames(a)
83+
A = state(a)
84+
A_denamed = denamed(A)
85+
style = TA.FusionStyle(A_denamed)
86+
ndims_co = Val(length(co))
87+
A_mat = TA.matricize(style, A_denamed, ndims_co)
88+
id_mat = similar(A_mat)
89+
copyto!(id_mat, I)
90+
biperm = TA.trivialbiperm(ndims_co, Val(ndims(A_denamed)))
91+
co_axes, dom_axes = TA.blocks(axes(A_denamed)[biperm])
92+
id_denamed = TA.unmatricize(style, id_mat, co_axes, dom_axes)
93+
id_nda = nameddims(id_denamed, dimnames(A))
94+
return operator(id_nda, co, dom)
95+
end

test/test_aqua.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@ using MatrixAlgebraKit: MatrixAlgebraKit
44
using Test: @testset
55

66
@testset "Code quality (Aqua.jl)" begin
7-
# `MatrixAlgebraKit.inv_regularized` is locally extended for
8-
# `AbstractNamedDimsArray` as a stand-in until the corresponding method
9-
# moves into `NamedDimsArrays.jl`. Whitelist it for the piracy check.
7+
# Stand-in Base / MAK extensions on `AbstractNamedDimsArray` /
8+
# `AbstractNamedDimsOperator` that will move upstream into
9+
# `NamedDimsArrays.jl` (or its operator extensions). Whitelist them
10+
# for the piracy check until the upstream PRs land:
11+
# * `MAK.inv_regularized` — N-d pseudo-inverse for named arrays.
12+
# * `Base.one` on `AbstractNamedDimsOperator` — identity operator,
13+
# analog of the existing `Base.sqrt` / `Base.exp` / … extensions
14+
# already defined in NDA's `MATRIX_FUNCTIONS` loop.
1015
Aqua.test_all(
1116
ITensorNetworksNext;
1217
persistent_tasks = false,
13-
piracies = (; treat_as_own = [MatrixAlgebraKit.inv_regularized])
18+
piracies = (; treat_as_own = [MatrixAlgebraKit.inv_regularized, Base.one])
1419
)
1520
end

0 commit comments

Comments
 (0)