@@ -165,15 +165,14 @@ function matricize(
165165 )
166166 return matricize (FusionStyle (a), a, perm_codomain, perm_domain)
167167end
168- # This is a more advanced version to overload where the permutation is actually performed.
168+ # Thin wrapper around `matricizeop` with identity op — the actual matricization logic
169+ # (and the fusion-style overload point for folding ops into matricization) lives in
170+ # `matricizeop`.
169171function matricize (
170172 style:: FusionStyle , a:: AbstractArray ,
171173 perm_codomain:: Tuple{Vararg{Int}} , perm_domain:: Tuple{Vararg{Int}}
172174 )
173- ndims (a) == length (perm_codomain) + length (perm_domain) ||
174- throw (ArgumentError (" Invalid bipermutation" ))
175- a_perm = bipermutedims (a, perm_codomain, perm_domain)
176- return matricize (style, a_perm, Val (length (perm_codomain)))
175+ return matricizeop (style, identity, a, perm_codomain, perm_domain)
177176end
178177
179178# Process inputs such as `EllipsisNotation.Ellipsis`.
@@ -238,11 +237,17 @@ function matricizeop(
238237 )
239238 return matricizeop (style, op, a, to_permblocks (a, (perm_codomain, perm_domain))... )
240239end
240+ # This is the primary function that should be overloaded for new fusion styles to fold
241+ # ops into matricization (e.g., fuse `conj` into the permutation copy, or use lazy
242+ # wrappers like StridedView with op metadata for zero-copy).
241243function matricizeop (
242244 style:: FusionStyle , op, a:: AbstractArray ,
243245 perm_codomain:: Tuple{Vararg{Int}} , perm_domain:: Tuple{Vararg{Int}}
244246 )
245- m = matricize (style, a, perm_codomain, perm_domain)
247+ ndims (a) == length (perm_codomain) + length (perm_domain) ||
248+ throw (ArgumentError (" Invalid bipermutation" ))
249+ a_perm = bipermutedims (a, perm_codomain, perm_domain)
250+ m = matricize (style, a_perm, Val (length (perm_codomain)))
246251 return _apply_op (op, m)
247252end
248253
0 commit comments