Skip to content

Commit 3c7f8f9

Browse files
committed
Invert matricize/matricizeop: matricizeop is primary, matricize wraps it
1 parent d37aebf commit 3c7f8f9

1 file changed

Lines changed: 11 additions & 6 deletions

File tree

src/matricize.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,15 +165,14 @@ function matricize(
165165
)
166166
return matricize(FusionStyle(a), a, perm_codomain, perm_domain)
167167
end
168-
# This is a more advanced version to overload where the permutation is actually performed.
168+
# Thin wrapper around `matricizeop` with identity op — the actual matricization logic
169+
# (and the fusion-style overload point for folding ops into matricization) lives in
170+
# `matricizeop`.
169171
function matricize(
170172
style::FusionStyle, a::AbstractArray,
171173
perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}
172174
)
173-
ndims(a) == length(perm_codomain) + length(perm_domain) ||
174-
throw(ArgumentError("Invalid bipermutation"))
175-
a_perm = bipermutedims(a, perm_codomain, perm_domain)
176-
return matricize(style, a_perm, Val(length(perm_codomain)))
175+
return matricizeop(style, identity, a, perm_codomain, perm_domain)
177176
end
178177

179178
# Process inputs such as `EllipsisNotation.Ellipsis`.
@@ -238,11 +237,17 @@ function matricizeop(
238237
)
239238
return matricizeop(style, op, a, to_permblocks(a, (perm_codomain, perm_domain))...)
240239
end
240+
# This is the primary function that should be overloaded for new fusion styles to fold
241+
# ops into matricization (e.g., fuse `conj` into the permutation copy, or use lazy
242+
# wrappers like StridedView with op metadata for zero-copy).
241243
function matricizeop(
242244
style::FusionStyle, op, a::AbstractArray,
243245
perm_codomain::Tuple{Vararg{Int}}, perm_domain::Tuple{Vararg{Int}}
244246
)
245-
m = matricize(style, a, perm_codomain, perm_domain)
247+
ndims(a) == length(perm_codomain) + length(perm_domain) ||
248+
throw(ArgumentError("Invalid bipermutation"))
249+
a_perm = bipermutedims(a, perm_codomain, perm_domain)
250+
m = matricize(style, a_perm, Val(length(perm_codomain)))
246251
return _apply_op(op, m)
247252
end
248253

0 commit comments

Comments
 (0)