|
21 | 21 |
|
22 | 22 | # tensoralloc_contract |
23 | 23 | # -------------------- |
24 | | -for TTB in (:AbstractTensorMap, :AbstractBlockTensorMap) |
| 24 | +for TTA in (:AbstractTensorMap, :AbstractBlockTensorMap), TTB in (:AbstractTensorMap, :AbstractBlockTensorMap) |
| 25 | + TTA == TTB == :AbstractTensorMap && continue |
25 | 26 | @eval function TO.tensorcontract_type( |
26 | 27 | TC, |
27 | | - A::AbstractBlockTensorMap, ::Index2Tuple, ::Bool, |
| 28 | + A::$TTA, ::Index2Tuple, ::Bool, |
28 | 29 | B::$TTB, ::Index2Tuple, ::Bool, |
29 | 30 | ::Index2Tuple{N₁, N₂}, |
30 | 31 | ) where {N₁, N₂} |
31 | 32 | S = TK.check_spacetype(A, B) |
32 | 33 | TC′ = TK.promote_permute(TC, sectortype(S)) |
33 | | - ATT = AbstractTensorMap{scalartype(A), spacetype(A), numout(A), numin(A)} |
34 | | - BTT = AbstractTensorMap{scalartype(B), spacetype(B), numout(B), numin(B)} |
35 | | - # handle case with BraidingTensors, so that they assume the backing |
36 | | - # array type of the concrete element type |
37 | | - M = if eltype(A) == ATT |
38 | | - TK.similarstoragetype(B, TC′) |
39 | | - elseif eltype(B) == BTT |
40 | | - TK.similarstoragetype(A, TC′) |
41 | | - else |
42 | | - TK.similarstoragetype(TK.similarstoragetype(A, TC′), TK.similarstoragetype(B, TC′)) |
43 | | - end |
| 34 | + M = TK.promote_storagetype(TK.similarstoragetype(A, TC′), TK.similarstoragetype(B, TC′)) |
44 | 35 | return if issparse(A) && issparse(B) |
45 | 36 | sparseblocktensormaptype(S, N₁, N₂, M) |
46 | 37 | else |
47 | 38 | blocktensormaptype(S, N₁, N₂, M) |
48 | 39 | end |
49 | 40 | end |
50 | 41 | end |
51 | | -TO.tensorcontract_type( |
52 | | - TC, |
53 | | - A::AbstractTensorMap, pA::Index2Tuple, conjA::Bool, |
54 | | - B::AbstractBlockTensorMap, pB::Index2Tuple, conjB::Bool, |
55 | | - pAB::Index2Tuple{N₁, N₂}, |
56 | | -) where {N₁, N₂} = TO.tensorcontract_type(TC, B, pB, conjB, A, pA, conjA, pAB) |
57 | 42 |
|
58 | 43 | function similarblocktype(::Type{A}, ::Type{TT}) where {A, TT} |
59 | 44 | return Core.Compiler.return_type(similar, Tuple{A, Type{TT}, NTuple{numind(TT), Int}}) |
|
0 commit comments