Skip to content

Commit c00af1f

Browse files
committed
add support for cat
1 parent 5c4199b commit c00af1f

2 files changed

Lines changed: 35 additions & 0 deletions

File tree

src/tensors/abstractblocktensor/abstractarray.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,25 @@ end
244244
function Base.similar(::Type{T}, P::TensorMapSumSpace) where {T<:AbstractBlockTensorMap}
245245
return T(undef, P)
246246
end
247+
248+
# Cat
249+
# ---
250+
Base.eltypeof(t::AbstractBlockTensorMap) = eltype(t)
251+
252+
@inline function Base._cat_t(
253+
dims, ::Type{T}, ts::AbstractBlockTensorMap...
254+
) where {T<:AbstractTensorMap}
255+
catdims = Base.dims2cat(dims)
256+
V = space(Base._cat(dims, eachspace.(ts)...))
257+
A = similar(ts[1], T, V)
258+
shape = size(A)
259+
if count(!iszero, catdims)::Int > 1
260+
zerovector!(A)
261+
end
262+
return Base.__cat(A, shape, catdims, ts...)
263+
end
264+
265+
Base._copy_or_fill!(A, inds, x::AbstractBlockTensorMap) = (A[inds...] = x)
266+
267+
# WHY DOES BASE NOT DEFAULT TO AXES
268+
Base.cat_indices(A::AbstractBlockTensorMap, d) = axes(A, d)

src/vectorspaces/sumspaceindices.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,16 @@ function subblockdims(V::ProductSumSpace{S,N}, c::Sector) where {S,N}
117117
)
118118
end
119119
end
120+
121+
function Base._cat(dims, A::SumSpaceIndices{S,N₁,N₂}...) where {S,N₁,N₂}
122+
@assert maximum(dims) <= N₁ + N₂ "Invalid number of spaces"
123+
catdims = Base.dims2cat(dims)
124+
Vs = ntuple(N₁ + N₂) do i
125+
return if i <= length(catdims) && catdims[i]
126+
((A[j].sumspaces[i] for j in 1:length(A))...)
127+
else
128+
A[1].sumspaces[i]
129+
end
130+
end
131+
return SumSpaceIndices{S,N₁,N₂}(Vs)
132+
end

0 commit comments

Comments
 (0)