Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ext/TensorKitChainRulesCoreExt/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,15 @@ function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTe
ipA = (codomainind(A), domainind(A))
pB = (allind(B), ())
dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B)))
tB = twist(B, filter(x -> isdual(space(B, x)), allind(B)))
tB = _twist_nocopy(B, filter(x -> isdual(space(B, x)), allind(B)))
dA = tensorcontract!(dA, ΔC, pΔC, false, tB, pB, true, ipA)
return projectA(dA)
end
dB_ = @thunk let
ipB = (codomainind(B), domainind(B))
pA = ((), allind(A))
dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A)))
tA = twist(A, filter(x -> isdual(space(A, x)), allind(A)))
tA = _twist_nocopy(A, filter(x -> isdual(space(A, x)), allind(A)))
dB = tensorcontract!(dB, tA, pA, true, ΔC, pΔC, false, ipB)
return projectB(dB)
end
Expand Down
13 changes: 10 additions & 3 deletions ext/TensorKitChainRulesCoreExt/tensoroperations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function ChainRulesCore.rrule(
# for non-symmetric tensors this might be more efficient like this,
# but for symmetric tensors an intermediate object will anyways be created
# and then it might be more efficient to use an addition and inner product
tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)))
tΔC = _twist_nocopy(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC)))
_dα = tensorscalar(
tensorcontract(
A, ((), linearize(pA)), !conjA,
Expand Down Expand Up @@ -74,7 +74,7 @@ function ChainRulesCore.rrule(
conjB′ = conjA ? conjB : !conjB
TA = promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))
# TODO: allocator
tB = twist(
tB = _twist_nocopy(
B,
TupleTools.vcat(
filter(x -> !isdual(space(B, x)), pB[1]),
Expand All @@ -99,7 +99,7 @@ function ChainRulesCore.rrule(
conjA′ = conjB ? conjA : !conjA
TB = promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))
# TODO: allocator
tA = twist(
tA = _twist_nocopy(
A,
TupleTools.vcat(
filter(x -> isdual(space(A, x)), pA[1]),
Expand Down Expand Up @@ -188,3 +188,10 @@ function ChainRulesCore.rrule(::typeof(TensorKit.scalar), t::AbstractTensorMap)
end
return val, scalar_pullback
end

# temporary function to avoid copies when not needed
# TODO: remove once `twist(t; copy=false)` is defined
function _twist_nocopy(t, inds; kwargs...)
(BraidingStyle(sectortype(t)) isa Fermionic && !isempty(inds)) || return t
Comment thread
lkdvos marked this conversation as resolved.
Outdated
return twist(t, inds; kwargs...)
end
Loading