Skip to content

Commit c8e19ff

Browse files
committed
temporary fix for TensorKit
1 parent a099ef4 commit c8e19ff

1 file changed

Lines changed: 74 additions & 0 deletions

File tree

src/tensors/tensoroperations.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,80 @@ function TO.tensoralloc(
154154
return C
155155
end
156156

157+
# unfortunate overlaod until TK fix
158+
function TK.blas_contract!(
159+
C::AbstractBlockTensorMap,
160+
A::AbstractBlockTensorMap, pA::Index2Tuple,
161+
B::AbstractBlockTensorMap, pB::Index2Tuple,
162+
pAB::Index2Tuple, α, β,
163+
backend, allocator
164+
)
165+
bstyle = BraidingStyle(sectortype(C))
166+
bstyle isa SymmetricBraiding ||
167+
throw(SectorMismatch("only tensors with symmetric braiding rules can be contracted; try `@planar` instead"))
168+
TC = scalartype(C)
169+
170+
# check which tensors have to be permuted/copied
171+
copyA = !(TO.isblascontractable(A, pA) && eltype(A) === TC)
172+
copyB = !(TO.isblascontractable(B, pB) && eltype(B) === TC)
173+
174+
if bstyle isa Fermionic && any(isdual Base.Fix1(space, B), pB[1])
175+
# twist smallest object if neither or both already have to be permuted
176+
# otherwise twist the one that already is copied
177+
if !(copyA copyB)
178+
twistA = dim(A) < dim(B)
179+
else
180+
twistA = copyA
181+
end
182+
twistB = !twistA
183+
copyA |= twistA
184+
copyB |= twistB
185+
else
186+
twistA = false
187+
twistB = false
188+
end
189+
190+
# Bring A in the correct form for BLAS contraction
191+
if copyA
192+
Anew = TO.tensoralloc_add(TC, A, pA, false, Val(true), allocator)
193+
Anew = TO.tensoradd!(Anew, A, pA, false, One(), Zero(), backend, allocator)
194+
twistA && twist!(Anew, filter(!isdual Base.Fix1(space, Anew), domainind(Anew)))
195+
else
196+
Anew = permute(A, pA)
197+
end
198+
pAnew = (codomainind(Anew), domainind(Anew))
199+
200+
# Bring B in the correct form for BLAS contraction
201+
if copyB
202+
Bnew = TO.tensoralloc_add(TC, B, pB, false, Val(true), allocator)
203+
Bnew = TO.tensoradd!(Bnew, B, pB, false, One(), Zero(), backend, allocator)
204+
twistB && twist!(Bnew, filter(isdual Base.Fix1(space, Bnew), codomainind(Bnew)))
205+
else
206+
Bnew = permute(B, pB)
207+
end
208+
pBnew = (codomainind(Bnew), domainind(Bnew))
209+
210+
# Bring C in the correct form for BLAS contraction
211+
ipAB = TO.oindABinC(pAB, pAnew, pBnew)
212+
copyC = !TO.isblasdestination(C, ipAB)
213+
214+
if copyC
215+
Cnew = TO.tensoralloc_add(TC, C, ipAB, false, Val(true), allocator)
216+
mul!(Cnew, Anew, Bnew)
217+
TO.tensoradd!(C, Cnew, pAB, false, α, β, backend, allocator)
218+
TO.tensorfree!(Cnew, allocator)
219+
else
220+
Cnew = permute(C, ipAB)
221+
mul!(Cnew, Anew, Bnew, α, β)
222+
end
223+
224+
copyA && TO.tensorfree!(Anew, allocator)
225+
copyB && TO.tensorfree!(Bnew, allocator)
226+
227+
return C
228+
end
229+
230+
157231
# tensorfree!
158232
# -----------
159233
function TO.tensorfree!(t::BlockTensorMap, allocator = TO.DefaultAllocator())

0 commit comments

Comments
 (0)