Skip to content

Commit 32936fc

Browse files
committed
Route matricizeop through new bipermutedimsop/bipermutedimsopadd!
1 parent 3c7f8f9 commit 32936fc

1 file changed

Lines changed: 40 additions & 7 deletions

File tree

src/matricize.jl

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,50 @@ function matricize_axes(a::AbstractArray, ndims_codomain::Val)
123123
return matricize_axes(FusionStyle(a), a, ndims_codomain)
124124
end
125125

126+
# `bipermutedimsopadd!` / `bipermutedimsop` — bipermutation versions of
127+
# `permutedimsopadd!` / `permuteddims` with an element-wise op folded in.
128+
#
129+
# These are intended to become the primary overload points for downstream array
130+
# types that want to fold ops into a bipartitioned permutation copy (e.g., fuse
131+
# `conj` into the copy, or use lazy wrappers like `StridedView` with op metadata).
132+
# For now, `bipermutedimsopadd!` delegates to the flat-permutation `permutedimsopadd!`.
133+
# In a future PR, the dependency will flip so that `permutedimsopadd!` wraps
134+
# `bipermutedimsopadd!`.
135+
136+
"""
137+
bipermutedimsopadd!(dest, op, src, perm_codomain, perm_domain, α, β)
138+
139+
Like `permutedimsopadd!`, but takes a bipartitioned permutation
140+
`(perm_codomain, perm_domain)`.
141+
"""
142+
function bipermutedimsopadd!(
143+
dest::AbstractArray, op, src::AbstractArray,
144+
perm_codomain, perm_domain,
145+
α::Number, β::Number
146+
)
147+
return permutedimsopadd!(dest, op, src, (perm_codomain..., perm_domain...), α, β)
148+
end
149+
150+
"""
151+
bipermutedimsop(op, src, perm_codomain, perm_domain)
152+
153+
Non-mutating version of `bipermutedimsopadd!`: returns
154+
`op.(permutedims(src, (perm_codomain..., perm_domain...)))`. Has "maybe alias"
155+
semantics — the result may be a view/wrapper aliasing `src` or a fresh copy.
156+
"""
157+
function bipermutedimsop(op, src::AbstractArray, perm_codomain, perm_domain)
158+
perm = (perm_codomain..., perm_domain...)
159+
dest = similar(src, map(i -> size(src, i), perm))
160+
return bipermutedimsopadd!(dest, op, src, perm_codomain, perm_domain, true, false)
161+
end
162+
126163
# Inner version takes a list of sub-permutations, overload this one if needed.
127164
# TODO: Remove _permutedims once support for Julia 1.10 is dropped
128165
# define permutedims with a BlockedPermuation. Default is to flatten it.
129166
# TODO: Deprecate `permuteblockeddims` in favor of `bipermutedims`.
130167
# Keeping it here for backwards compatibility.
131168
function bipermutedims(a::AbstractArray, perm1, perm2)
132-
return _permutedims(a, (perm1..., perm2...))
169+
return bipermutedimsop(identity, a, perm1, perm2)
133170
end
134171
function bipermutedims!(a_dest::AbstractArray, a_src::AbstractArray, perm1, perm2)
135172
return _permutedims!(a_dest, a_src, (perm1..., perm2...))
@@ -246,14 +283,10 @@ function matricizeop(
246283
)
247284
ndims(a) == length(perm_codomain) + length(perm_domain) ||
248285
throw(ArgumentError("Invalid bipermutation"))
249-
a_perm = bipermutedims(a, perm_codomain, perm_domain)
250-
m = matricize(style, a_perm, Val(length(perm_codomain)))
251-
return _apply_op(op, m)
286+
a_perm_op = bipermutedimsop(op, a, perm_codomain, perm_domain)
287+
return matricize(style, a_perm_op, Val(length(perm_codomain)))
252288
end
253289

254-
_apply_op(::typeof(identity), m::AbstractMatrix) = m
255-
_apply_op(op, m::AbstractMatrix) = op.(m)
256-
257290
# ==================================== unmatricize =======================================
258291
# This is the primary function that should be overloaded for new fusion styles.
259292
function unmatricize(

0 commit comments

Comments
 (0)