@@ -123,13 +123,50 @@ function matricize_axes(a::AbstractArray, ndims_codomain::Val)
123123 return matricize_axes (FusionStyle (a), a, ndims_codomain)
124124end
125125
126+ # `bipermutedimsopadd!` / `bipermutedimsop` — bipermutation versions of
127+ # `permutedimsopadd!` / `permuteddims` with an element-wise op folded in.
128+ #
129+ # These are intended to become the primary overload points for downstream array
130+ # types that want to fold ops into a bipartitioned permutation copy (e.g., fuse
131+ # `conj` into the copy, or use lazy wrappers like `StridedView` with op metadata).
132+ # For now, `bipermutedimsopadd!` delegates to the flat-permutation `permutedimsopadd!`.
133+ # In a future PR, the dependency will flip so that `permutedimsopadd!` wraps
134+ # `bipermutedimsopadd!`.
135+
136+ """
137+ bipermutedimsopadd!(dest, op, src, perm_codomain, perm_domain, α, β)
138+
139+ Like `permutedimsopadd!`, but takes a bipartitioned permutation
140+ `(perm_codomain, perm_domain)`.
141+ """
142+ function bipermutedimsopadd! (
143+ dest:: AbstractArray , op, src:: AbstractArray ,
144+ perm_codomain, perm_domain,
145+ α:: Number , β:: Number
146+ )
147+ return permutedimsopadd! (dest, op, src, (perm_codomain... , perm_domain... ), α, β)
148+ end
149+
150+ """
151+ bipermutedimsop(op, src, perm_codomain, perm_domain)
152+
153+ Non-mutating version of `bipermutedimsopadd!`: returns
154+ `op.(permutedims(src, (perm_codomain..., perm_domain...)))`. Has "maybe alias"
155+ semantics — the result may be a view/wrapper aliasing `src` or a fresh copy.
156+ """
157+ function bipermutedimsop (op, src:: AbstractArray , perm_codomain, perm_domain)
158+ perm = (perm_codomain... , perm_domain... )
159+ dest = similar (src, map (i -> size (src, i), perm))
160+ return bipermutedimsopadd! (dest, op, src, perm_codomain, perm_domain, true , false )
161+ end
162+
126163# Inner version takes a list of sub-permutations, overload this one if needed.
127164# TODO : Remove _permutedims once support for Julia 1.10 is dropped
128165# define permutedims with a BlockedPermuation. Default is to flatten it.
129166# TODO : Deprecate `permuteblockeddims` in favor of `bipermutedims`.
130167# Keeping it here for backwards compatibility.
131168function bipermutedims (a:: AbstractArray , perm1, perm2)
132- return _permutedims ( a, ( perm1... , perm2... ) )
169+ return bipermutedimsop (identity, a, perm1, perm2)
133170end
134171function bipermutedims! (a_dest:: AbstractArray , a_src:: AbstractArray , perm1, perm2)
135172 return _permutedims! (a_dest, a_src, (perm1... , perm2... ))
@@ -246,14 +283,10 @@ function matricizeop(
246283 )
247284 ndims (a) == length (perm_codomain) + length (perm_domain) ||
248285 throw (ArgumentError (" Invalid bipermutation" ))
249- a_perm = bipermutedims (a, perm_codomain, perm_domain)
250- m = matricize (style, a_perm, Val (length (perm_codomain)))
251- return _apply_op (op, m)
286+ a_perm_op = bipermutedimsop (op, a, perm_codomain, perm_domain)
287+ return matricize (style, a_perm_op, Val (length (perm_codomain)))
252288end
253289
254- _apply_op (:: typeof (identity), m:: AbstractMatrix ) = m
255- _apply_op (op, m:: AbstractMatrix ) = op .(m)
256-
257290# ==================================== unmatricize =======================================
258291# This is the primary function that should be overloaded for new fusion styles.
259292function unmatricize (
0 commit comments