|
1 | | -function TK.add_transform!( |
2 | | - tdst::BlockTensorMap, tsrc::BlockTensorMap, (p₁, p₂)::Index2Tuple{N₁, N₂}, |
3 | | - fusiontreetransform, |
4 | | - α::Number, β::Number, |
5 | | - backend::AbstractBackend..., |
6 | | - ) where {N₁, N₂} |
7 | | - @boundscheck begin |
8 | | - permute(space(tsrc), (p₁, p₂)) == space(tdst) || |
9 | | - throw(SpaceMismatch("source = $(codomain(tsrc))←$(domain(tsrc)), |
10 | | - dest = $(codomain(tdst))←$(domain(tdst)), p₁ = $(p₁), p₂ = $(p₂)")) |
11 | | - end |
| 1 | +@propagate_inbounds function TK.add_transform!( |
| 2 | + tdst::BlockTensorMap, tsrc::BlockTensorMap, p::Index2Tuple, transformer, |
| 3 | + α::Number, β::Number, backend, allocator |
| 4 | + ) |
| 5 | + @boundscheck TK.spacecheck_transform(permute, tdst, tsrc, p) |
| 6 | + |
12 | 7 | dstdata = parent(tdst) |
13 | | - srcdata = permutedims(StridedView(parent(tsrc)), (p₁..., p₂...)) |
| 8 | + srcdata = permutedims(StridedView(parent(tsrc)), (p[1]..., p[2]...)) |
14 | 9 |
|
15 | 10 | @inbounds for I in eachindex(dstdata, srcdata) |
16 | 11 | dstdata[I] = TK.add_transform!( |
17 | | - dstdata[I], srcdata[I], (p₁, p₂), fusiontreetransform, α, β, backend... |
| 12 | + dstdata[I], srcdata[I], p, transformer, α, β, backend, allocator |
18 | 13 | ) |
19 | 14 | end |
20 | 15 | return tdst |
21 | 16 | end |
22 | | -function TK.add_transform!( |
| 17 | +@propagate_inbounds function TK.add_transform!( |
23 | 18 | tdst::AbstractBlockTensorMap, tsrc::AbstractBlockTensorMap, |
24 | | - (p₁, p₂)::Index2Tuple{N₁, N₂}, |
25 | | - fusiontreetransform, |
26 | | - α::Number, β::Number, |
27 | | - backend::AbstractBackend..., |
28 | | - ) where {N₁, N₂} |
29 | | - @boundscheck begin |
30 | | - permute(space(tsrc), (p₁, p₂)) == space(tdst) || |
31 | | - throw(SpaceMismatch("source = $(codomain(tsrc))←$(domain(tsrc)), |
32 | | - dest = $(codomain(tdst))←$(domain(tdst)), p₁ = $(p₁), p₂ = $(p₂)")) |
33 | | - end |
| 19 | + p::Index2Tuple, transformer, α::Number, β::Number, backend, allocator |
| 20 | + ) |
| 21 | + @boundscheck TK.spacecheck_transform(permute, tdst, tsrc, p) |
34 | 22 | scale!(tdst, β) |
35 | | - p = (p₁..., p₂...) |
| 23 | + p_lin = (p[1]..., p[2]...) |
36 | 24 | @inbounds for (I, v) in nonzero_pairs(tsrc) |
37 | | - I′ = CartesianIndex(TT.getindices(I.I, p)) |
| 25 | + I′ = CartesianIndex(TT.getindices(I.I, p_lin)) |
38 | 26 | tdst[I′] = TK.add_transform!( |
39 | | - tdst[I′], v, (p₁, p₂), fusiontreetransform, α, One(), backend... |
| 27 | + tdst[I′], v, p_lin, transformer, α, One(), backend, allocator |
40 | 28 | ) |
41 | 29 | end |
42 | 30 | return tdst |
43 | 31 | end |
44 | 32 | function TK.add_transform!( |
45 | 33 | tdst::AbstractBlockTensorMap, tsrc::AdjointTensorMap{T, S, N₁, N₂, TT}, |
46 | | - (p₁, p₂)::Index2Tuple, |
47 | | - fusiontreetransform, |
48 | | - α::Number, β::Number, |
49 | | - backend::AbstractBackend..., |
| 34 | + p::Index2Tuple, transformer, α::Number, β::Number, backend, allocator |
50 | 35 | ) where {T, S, N₁, N₂, TT <: AbstractBlockTensorMap} |
51 | | - @boundscheck begin |
52 | | - permute(space(tsrc), (p₁, p₂)) == space(tdst) || |
53 | | - throw(SpaceMismatch("source = $(codomain(tsrc))←$(domain(tsrc)), |
54 | | - dest = $(codomain(tdst))←$(domain(tdst)), p₁ = $(p₁), p₂ = $(p₂)")) |
55 | | - end |
| 36 | + @boundscheck TK.spacecheck_transform(permute, tdst, tsrc, p) |
56 | 37 | scale!(tdst, β) |
57 | | - p = (p₁..., p₂...) |
| 38 | + p_lin = (p[1]..., p[2]...) |
58 | 39 | @inbounds for (I, v) in nonzero_pairs(tsrc) |
59 | | - I′ = CartesianIndex(TT.getindices(I.I, p)) |
| 40 | + I′ = CartesianIndex(TT.getindices(I.I, p_lin)) |
60 | 41 | tdst[I′] = TK.add_transform!( |
61 | | - tdst[I′], v, (p₁, p₂), fusiontreetransform, α, One(), backend... |
| 42 | + tdst[I′], v, p, transformer, α, One(), backend, allocator |
62 | 43 | ) |
63 | 44 | end |
64 | 45 | return tdst |
65 | 46 | end |
66 | 47 | function TK.add_transform!( |
67 | | - tdst::TensorMap, tsrc::BlockTensorMap, (p₁, p₂)::Index2Tuple, |
68 | | - fusiontreetransform, |
69 | | - α::Number, β::Number, |
70 | | - backend::AbstractBackend..., |
| 48 | + tdst::TensorMap, tsrc::BlockTensorMap, p::Index2Tuple, transformer, |
| 49 | + α::Number, β::Number, backend, allocator |
71 | 50 | ) |
72 | | - @assert length(tsrc) == 1 "source tensor must be a single tensor" |
73 | 51 | return TK.add_transform!( |
74 | | - tdst, only(tsrc), (p₁, p₂), fusiontreetransform, α, β, backend... |
| 52 | + tdst, only(tsrc), p, transformer, α, β, backend, allocator |
75 | 53 | ) |
76 | 54 | end |
77 | 55 | function TK.add_transform!( |
78 | | - tdst::BlockTensorMap, tsrc::TensorMap, |
79 | | - (p₁, p₂)::Index2Tuple, |
80 | | - fusiontreetransform, |
81 | | - α::Number, β::Number, |
82 | | - backend::AbstractBackend..., |
| 56 | + tdst::BlockTensorMap, tsrc::TensorMap, p::Index2Tuple, transformer, |
| 57 | + α::Number, β::Number, backend, allocator |
83 | 58 | ) |
84 | | - # @assert length(tsrc) == 1 "source tensor must be a single tensor" |
85 | 59 | return TK.add_transform!( |
86 | | - only(tdst), tsrc, (p₁, p₂), fusiontreetransform, α, β, backend... |
| 60 | + only(tdst), tsrc, p, transformer, α, β, backend, allocator |
87 | 61 | ) |
88 | 62 | end |
89 | 63 |
|
90 | 64 | # we need to capture the other functions earlier to enjoy the fast transformers... |
91 | | -for f! in (:add_permute!, :add_transpose!) |
92 | | - @eval function TK.$f!( |
93 | | - tdst::BlockTensorMap, tsrc::BlockTensorMap, |
94 | | - (p₁, p₂)::Index2Tuple{N₁, N₂}, |
95 | | - α::Number, β::Number, |
96 | | - backend::AbstractBackend..., |
97 | | - ) where {N₁, N₂} |
98 | | - @boundscheck begin |
99 | | - permute(space(tsrc), (p₁, p₂)) == space(tdst) || |
100 | | - throw(SpaceMismatch("source = $(codomain(tsrc))←$(domain(tsrc)), |
101 | | - dest = $(codomain(tdst))←$(domain(tdst)), p₁ = $(p₁), p₂ = $(p₂)")) |
102 | | - end |
103 | | - dstdata = parent(tdst) |
104 | | - srcdata = permutedims(StridedView(parent(tsrc)), (p₁..., p₂...)) |
| 65 | +for f in (:permute, :transpose) |
| 66 | + f! = Symbol(f, :!) |
| 67 | + @eval begin |
| 68 | + function TK.$f!( |
| 69 | + tdst::BlockTensorMap, tsrc::BlockTensorMap, p::Index2Tuple, |
| 70 | + α::Number, β::Number, backend::AbstractBackend, allocator |
| 71 | + ) |
| 72 | + @boundscheck TK.spacecheck_transform(TK.$f, tdst, tsrc, p) |
105 | 73 |
|
106 | | - @inbounds for I in eachindex(dstdata, srcdata) |
107 | | - dstdata[I] = TK.$f!(dstdata[I], srcdata[I], (p₁, p₂), α, β, backend...) |
| 74 | + dstdata = parent(tdst) |
| 75 | + srcdata = permutedims(StridedView(parent(tsrc)), (p[1]..., p[2]...)) |
| 76 | + |
| 77 | + @inbounds for I in eachindex(dstdata, srcdata) |
| 78 | + dstdata[I] = TK.$f!(dstdata[I], srcdata[I], p, α, β, backend, allocator) |
| 79 | + end |
| 80 | + return tdst |
108 | 81 | end |
109 | | - return tdst |
110 | | - end |
111 | | - @eval function TK.$f!( |
112 | | - tdst::AbstractBlockTensorMap, tsrc::AbstractBlockTensorMap, |
113 | | - (p₁, p₂)::Index2Tuple{N₁, N₂}, |
114 | | - α::Number, β::Number, |
115 | | - backend::AbstractBackend..., |
116 | | - ) where {N₁, N₂} |
117 | | - @boundscheck begin |
118 | | - permute(space(tsrc), (p₁, p₂)) == space(tdst) || |
119 | | - throw(SpaceMismatch("source = $(codomain(tsrc))←$(domain(tsrc)), |
120 | | - dest = $(codomain(tdst))←$(domain(tdst)), p₁ = $(p₁), p₂ = $(p₂)")) |
| 82 | + function TK.$f!( |
| 83 | + tdst::AbstractBlockTensorMap, tsrc::AbstractBlockTensorMap, |
| 84 | + p::Index2Tuple, α::Number, β::Number, backend::AbstractBackend, allocator |
| 85 | + ) |
| 86 | + @boundscheck TK.spacecheck_transform(TK.$f, tdst, tsrc, p) |
| 87 | + scale!(tdst, β) |
| 88 | + p_lin = (p[1]..., p[2]...) |
| 89 | + @inbounds for (I, v) in nonzero_pairs(tsrc) |
| 90 | + I′ = CartesianIndex(TT.getindices(I.I, p_lin)) |
| 91 | + tdst[I′] = TK.$f!(tdst[I′], v, p, α, One(), backend, allocator) |
| 92 | + end |
| 93 | + return tdst |
121 | 94 | end |
122 | | - scale!(tdst, β) |
123 | | - p = (p₁..., p₂...) |
124 | | - @inbounds for (I, v) in nonzero_pairs(tsrc) |
125 | | - I′ = CartesianIndex(TT.getindices(I.I, p)) |
126 | | - tdst[I′] = TK.$f!(tdst[I′], v, (p₁, p₂), α, One(), backend...) |
| 95 | + function TK.$f!( |
| 96 | + tdst::AbstractBlockTensorMap, tsrc::AdjointTensorMap{T, S, N₁, N₂, TT}, |
| 97 | + p::Index2Tuple, α::Number, β::Number, backend::AbstractBackend, allocator |
| 98 | + ) where {T, S, N₁, N₂, TT <: AbstractBlockTensorMap} |
| 99 | + @boundscheck TK.spacecheck_transform(TK.$f, tdst, tsrc, p) |
| 100 | + scale!(tdst, β) |
| 101 | + p_lin = (p[1]..., p[2]...) |
| 102 | + @inbounds for (I, v) in nonzero_pairs(tsrc) |
| 103 | + I′ = CartesianIndex(TT.getindices(I.I, p)) |
| 104 | + tdst[I′] = TK.$f!(tdst[I′], v, (p₁, p₂), α, One(), backend, allocator) |
| 105 | + end |
| 106 | + return tdst |
127 | 107 | end |
128 | | - return tdst |
129 | | - end |
130 | | - @eval function TK.$f!( |
131 | | - tdst::AbstractBlockTensorMap, tsrc::AdjointTensorMap{T, S, N₁, N₂, TT}, |
132 | | - (p₁, p₂)::Index2Tuple, |
133 | | - α::Number, β::Number, |
134 | | - backend::AbstractBackend..., |
135 | | - ) where {T, S, N₁, N₂, TT <: AbstractBlockTensorMap} |
136 | | - @boundscheck begin |
137 | | - permute(space(tsrc), (p₁, p₂)) == space(tdst) || |
138 | | - throw(SpaceMismatch("source = $(codomain(tsrc))←$(domain(tsrc)), |
139 | | - dest = $(codomain(tdst))←$(domain(tdst)), p₁ = $(p₁), p₂ = $(p₂)")) |
| 108 | + function TK.$f!( |
| 109 | + tdst::TensorMap, tsrc::BlockTensorMap, p::Index2Tuple, |
| 110 | + α::Number, β::Number, backend::AbstractBackend, allocator |
| 111 | + ) |
| 112 | + return TK.$f!(tdst, only(tsrc), p, α, β, backend, allocator) |
140 | 113 | end |
141 | | - scale!(tdst, β) |
142 | | - p = (p₁..., p₂...) |
143 | | - @inbounds for (I, v) in nonzero_pairs(tsrc) |
144 | | - I′ = CartesianIndex(TT.getindices(I.I, p)) |
145 | | - tdst[I′] = TK.$f!(tdst[I′], v, (p₁, p₂), α, One(), backend...) |
| 114 | + function TK.$f!( |
| 115 | + tdst::BlockTensorMap, tsrc::TensorMap, p::Index2Tuple, |
| 116 | + α::Number, β::Number, backend::AbstractBackend, allocator |
| 117 | + ) |
| 118 | + TK.$f!(only(tdst), tsrc, p, α, β, backend, allocator) |
| 119 | + return tdst |
146 | 120 | end |
147 | | - return tdst |
148 | | - end |
149 | | - @eval function TK.$f!( |
150 | | - tdst::TensorMap, tsrc::BlockTensorMap, |
151 | | - (p₁, p₂)::Index2Tuple, |
152 | | - α::Number, β::Number, |
153 | | - backend::AbstractBackend..., |
154 | | - ) |
155 | | - @assert length(tsrc) == 1 "source tensor must be a single tensor" |
156 | | - return TK.$f!(tdst, only(tsrc), (p₁, p₂), α, β, backend...) |
157 | 121 | end |
158 | 122 | end |
159 | 123 |
|
160 | | -function TK.add_braid!( |
161 | | - tdst::BlockTensorMap, tsrc::BlockTensorMap, |
162 | | - (p₁, p₂)::Index2Tuple{N₁, N₂}, |
163 | | - levels::IndexTuple, |
164 | | - α::Number, β::Number, |
165 | | - backend::AbstractBackend..., |
166 | | - ) where {N₁, N₂} |
167 | | - @boundscheck begin |
168 | | - permute(space(tsrc), (p₁, p₂)) == space(tdst) || |
169 | | - throw(SpaceMismatch("source = $(codomain(tsrc))←$(domain(tsrc)), |
170 | | - dest = $(codomain(tdst))←$(domain(tdst)), p₁ = $(p₁), p₂ = $(p₂)")) |
171 | | - end |
| 124 | +@propagate_inbounds function TK.braid!( |
| 125 | + tdst::BlockTensorMap, tsrc::BlockTensorMap, p::Index2Tuple, levels::IndexTuple, |
| 126 | + α::Number, β::Number, backend::AbstractBackend, allocator |
| 127 | + ) |
| 128 | + @boundscheck TK.spacecheck_transform(braid, tdst, tsrc, p, levels) |
| 129 | + |
172 | 130 | dstdata = parent(tdst) |
173 | | - srcdata = permutedims(StridedView(parent(tsrc)), (p₁..., p₂...)) |
| 131 | + srcdata = permutedims(StridedView(parent(tsrc)), (p[1]..., p[2]...)) |
174 | 132 |
|
175 | 133 | @inbounds for I in eachindex(dstdata, srcdata) |
176 | | - dstdata[I] = TK.add_braid!( |
177 | | - dstdata[I], srcdata[I], (p₁, p₂), levels, α, β, backend... |
| 134 | + dstdata[I] = TK.braid!( |
| 135 | + dstdata[I], srcdata[I], p, levels, α, β, backend, allocator |
178 | 136 | ) |
179 | 137 | end |
180 | 138 | return tdst |
181 | 139 | end |
182 | | -function TK.add_braid!( |
| 140 | +@propagate_inbounds function TK.braid!( |
183 | 141 | tdst::AbstractBlockTensorMap, tsrc::AbstractBlockTensorMap, |
184 | | - (p₁, p₂)::Index2Tuple{N₁, N₂}, |
185 | | - levels::IndexTuple, |
186 | | - α::Number, β::Number, |
187 | | - backend::AbstractBackend..., |
188 | | - ) where {N₁, N₂} |
189 | | - @boundscheck begin |
190 | | - permute(space(tsrc), (p₁, p₂)) == space(tdst) || |
191 | | - throw(SpaceMismatch("source = $(codomain(tsrc))←$(domain(tsrc)), |
192 | | - dest = $(codomain(tdst))←$(domain(tdst)), p₁ = $(p₁), p₂ = $(p₂)")) |
193 | | - end |
| 142 | + p::Index2Tuple, levels::IndexTuple, |
| 143 | + α::Number, β::Number, backend::AbstractBackend, allocator |
| 144 | + ) |
| 145 | + @boundscheck TK.spacecheck_transform(braid, tdst, tsrc, p, levels) |
194 | 146 | scale!(tdst, β) |
195 | | - p = (p₁..., p₂...) |
| 147 | + p_lin = (p[1]..., p[2]...) |
196 | 148 | @inbounds for (I, v) in nonzero_pairs(tsrc) |
197 | | - I′ = CartesianIndex(TT.getindices(I.I, p)) |
198 | | - tdst[I′] = TK.add_braid!(tdst[I′], v, (p₁, p₂), levels, α, One(), backend...) |
| 149 | + I′ = CartesianIndex(TT.getindices(I.I, p_lin)) |
| 150 | + tdst[I′] = TK.braid!(tdst[I′], v, p, levels, α, One(), backend, allocator) |
199 | 151 | end |
200 | 152 | return tdst |
201 | 153 | end |
202 | | -function TK.add_braid!( |
| 154 | +function TK.braid!( |
203 | 155 | tdst::AbstractBlockTensorMap, tsrc::AdjointTensorMap{T, S, N₁, N₂, TT}, |
204 | | - (p₁, p₂)::Index2Tuple, |
205 | | - levels::IndexTuple, |
206 | | - α::Number, β::Number, |
207 | | - backend::AbstractBackend..., |
| 156 | + p::Index2Tuple, levels::IndexTuple, α::Number, β::Number, backend::AbstractBackend, allocator |
208 | 157 | ) where {T, S, N₁, N₂, TT <: AbstractBlockTensorMap} |
209 | | - @boundscheck begin |
210 | | - permute(space(tsrc), (p₁, p₂)) == space(tdst) || |
211 | | - throw(SpaceMismatch("source = $(codomain(tsrc))←$(domain(tsrc)), |
212 | | - dest = $(codomain(tdst))←$(domain(tdst)), p₁ = $(p₁), p₂ = $(p₂)")) |
213 | | - end |
| 158 | + @boundscheck TK.spacecheck_transform(braid, tdst, tsrc, p, levels) |
214 | 159 | scale!(tdst, β) |
215 | | - p = (p₁..., p₂...) |
| 160 | + p_lin = (p[1]..., p[2]...) |
216 | 161 | @inbounds for (I, v) in nonzero_pairs(tsrc) |
217 | | - I′ = CartesianIndex(TT.getindices(I.I, p)) |
218 | | - tdst[I′] = TK.add_braid!(tdst[I′], v, (p₁, p₂), levels, α, One(), backend...) |
| 162 | + I′ = CartesianIndex(TT.getindices(I.I, p_lin)) |
| 163 | + tdst[I′] = TK.braid!(tdst[I′], v, p, levels, α, One(), backend, allocator) |
219 | 164 | end |
220 | 165 | return tdst |
221 | 166 | end |
222 | | -function TK.add_braid!( |
| 167 | +function TK.braid!( |
223 | 168 | tdst::TensorMap, tsrc::BlockTensorMap, |
224 | | - (p₁, p₂)::Index2Tuple, |
225 | | - levels::IndexTuple, |
226 | | - α::Number, β::Number, |
227 | | - backend::AbstractBackend..., |
| 169 | + p::Index2Tuple, levels::IndexTuple, |
| 170 | + α::Number, β::Number, backend::AbstractBackend, allocator |
| 171 | + ) |
| 172 | + return TK.add_braid!(tdst, only(tsrc), p, levels, α, β, backend, allocator) |
| 173 | +end |
| 174 | +function TK.braid!( |
| 175 | + tdst::BlockTensorMap, tsrc::TensorMap, |
| 176 | + p::Index2Tuple, levels::IndexTuple, |
| 177 | + α::Number, β::Number, backend::AbstractBackend, allocator |
228 | 178 | ) |
229 | | - @assert length(tsrc) == 1 "source tensor must be a single tensor" |
230 | | - return TK.add_braid!(tdst, only(tsrc), (p₁, p₂), levels, α, β, backend...) |
| 179 | + TK.braid!(only(tdst), tsrc, p, levels, α, β, backend, allocator) |
| 180 | + return tdst |
231 | 181 | end |
232 | 182 |
|
233 | 183 | Base.@constprop :aggressive function TK.insertleftunit( |
|
0 commit comments