Skip to content

Commit c4fdb2d

Browse files
committed
unify tensortype usage
also update `tensoralloc`
1 parent 2fd9f12 commit c4fdb2d

4 files changed

Lines changed: 45 additions & 56 deletions

File tree

src/auxiliary/auxiliary.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,6 @@ end
8686
else
8787
_allequal(f, xs) = allequal(f, xs)
8888
end
89+
90+
Base.@assume_effects :foldable parenttype(::Type{T}) where {T} =
91+
Core.Compiler.return_type(parent, Tuple{T})

src/tensors/abstracttensor.jl

Lines changed: 20 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -496,54 +496,39 @@ See also [`similar_diagonal`](@ref).
496496
""" Base.similar(::AbstractTensorMap, args...)
497497

498498
function Base.similar(
499-
t::AbstractTensorMap, ::Type{T}, codomain::TensorSpace{S}, domain::TensorSpace{S}
500-
) where {T, S}
499+
t::AbstractTensorMap, ::Type{T}, codomain::TensorSpace, domain::TensorSpace
500+
) where {T}
501501
return similar(t, T, codomain domain)
502502
end
503+
503504
# 3 arguments
504-
function Base.similar(
505-
t::AbstractTensorMap, codomain::TensorSpace{S}, domain::TensorSpace{S}
506-
) where {S}
507-
return similar(t, similarstoragetype(t), codomain domain)
508-
end
509-
function Base.similar(t::AbstractTensorMap, ::Type{T}, codomain::TensorSpace) where {T}
510-
return similar(t, T, codomain one(codomain))
511-
end
505+
Base.similar(t::AbstractTensorMap, codomain::TensorSpace, domain::TensorSpace) =
506+
similar(t, similarstoragetype(t), codomain domain)
507+
Base.similar(t::AbstractTensorMap, ::Type{T}, codomain::TensorSpace) where {T} =
508+
similar(t, T, codomain one(codomain))
509+
512510
# 2 arguments
513-
function Base.similar(t::AbstractTensorMap, codomain::TensorSpace)
514-
return similar(t, similarstoragetype(t), codomain one(codomain))
515-
end
516-
Base.similar(t::AbstractTensorMap, P::TensorMapSpace) = similar(t, storagetype(t), P)
511+
Base.similar(t::AbstractTensorMap, codomain::TensorSpace) =
512+
similar(t, similarstoragetype(t), codomain one(codomain))
513+
Base.similar(t::AbstractTensorMap, V::TensorMapSpace) = similar(t, similarstoragetype(t), V)
517514
Base.similar(t::AbstractTensorMap, ::Type{T}) where {T} = similar(t, T, space(t))
518515
# 1 argument
519516
Base.similar(t::AbstractTensorMap) = similar(t, similarstoragetype(t), space(t))
520517

521518
# generic implementation for AbstractTensorMap -> returns `TensorMap`
522-
function Base.similar(t::AbstractTensorMap, ::Type{TorA}, P::TensorMapSpace{S}) where {TorA, S}
523-
if TorA <: Number
524-
T = TorA
525-
A = similarstoragetype(t, T)
526-
elseif TorA <: DenseVector
527-
A = TorA
528-
T = scalartype(A)
529-
else
530-
throw(ArgumentError("Type $TorA not supported for similar"))
531-
end
532-
533-
N₁ = length(codomain(P))
534-
N₂ = length(domain(P))
535-
return TensorMap{T, S, N₁, N₂, A}(undef, P)
519+
function Base.similar(t::AbstractTensorMap, ::Type{TorA}, V::TensorMapSpace) where {TorA}
520+
A = TorA <: Number ? similarstoragetype(t, TorA) : TorA
521+
TT = tensormaptype(spacetype(V), numout(V), numin(V), A)
522+
return TT(undef, V)
536523
end
537524

538525
# implementation in type-domain
539-
function Base.similar(::Type{TT}, P::TensorMapSpace) where {TT <: AbstractTensorMap}
540-
return TensorMap{scalartype(TT)}(undef, P)
541-
end
542-
function Base.similar(
543-
::Type{TT}, cod::TensorSpace{S}, dom::TensorSpace{S}
544-
) where {TT <: AbstractTensorMap, S}
545-
return TensorMap{scalartype(TT)}(undef, cod, dom)
526+
function Base.similar(::Type{TT}, V::TensorMapSpace) where {TT <: AbstractTensorMap}
527+
TT′ = tensormaptype(spacetype(V), numout(V), numin(V), similarstoragetype(TT))
528+
return TT′(undef, V)
546529
end
530+
Base.similar(::Type{TT}, cod::TensorSpace, dom::TensorSpace) where {TT <: AbstractTensorMap} =
531+
similar(TT, cod dom)
547532

548533
# similar diagonal
549534
# ----------------

src/tensors/tensor.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ struct TensorMap{T, S <: IndexSpace, N₁, N₂, A <: DenseVector{T}} <: Abstrac
3131
I = sectortype(S)
3232
T <: Real && !(sectorscalartype(I) <: Real) &&
3333
@warn("Tensors with real data might be incompatible with sector type $I", maxlog = 1)
34+
d = fusionblockstructure(space).totaldim
35+
length(data) == d || throw(DimensionMismatch("invalid length of data"))
3436
return new{T, S, N₁, N₂, A}(data, space)
3537
end
3638
end
@@ -47,19 +49,20 @@ i.e. a tensor map with only a non-trivial output space.
4749
const Tensor{T, S, N, A} = TensorMap{T, S, N, 0, A}
4850

4951
function tensormaptype(S::Type{<:IndexSpace}, N₁, N₂, TorA::Type)
50-
if TorA <: Number
51-
return TensorMap{TorA, S, N₁, N₂, Vector{TorA}}
52-
elseif TorA <: DenseVector
53-
return TensorMap{scalartype(TorA), S, N₁, N₂, TorA}
54-
else
55-
throw(ArgumentError("argument $TorA should specify a scalar type (`<:Number`) or a storage type `<:DenseVector{<:Number}`"))
56-
end
52+
A = _tensormap_storagetype(TorA)
53+
A <: DenseVector || throw(ArgumentError("Cannot determine a valid storage type from argument $TorA"))
54+
return TensorMap{scalartype(A), S, N₁, N₂, A}
5755
end
5856

5957
# hook for mapping input types to storage types -- to be implemented in extensions
60-
_tensormap_storagetype(::Type{A}) where {A <: AbstractArray} = _tensormap_storagetype(scalartype(A))
6158
_tensormap_storagetype(::Type{A}) where {A <: DenseVector{<:Number}} = A
59+
_tensormap_storagetype(::Type{A}) where {A <: Array} = _tensormap_storagetype(scalartype(A))
6260
_tensormap_storagetype(::Type{T}) where {T <: Number} = Vector{T}
61+
function _tensormap_storagetype(::Type{A}) where {A <: AbstractArray}
62+
PA = parenttype(A)
63+
PA === A && throw(MethodError(_tensormap_storagetype, A)) # avoid infinite recursion
64+
return _tensormap_storagetype(PA)
65+
end
6366

6467
# Basic methods for characterising a tensor:
6568
#--------------------------------------------
@@ -95,7 +98,7 @@ const TensorWithStorage{T, A <: DenseVector{T}, S, N} = Tensor{T, S, N, A}
9598
Construct a `TensorMap` with uninitialized data with elements of type `T`.
9699
"""
97100
TensorMap{T}(::UndefInitializer, V::TensorMapSpace) where {T} =
98-
TensorMapWithStorage{T, _tensormap_storagetype(T)}(undef, V)
101+
tensormaptype(spacetype(V), numout(V), numin(V), T)(undef, V)
99102
TensorMap{T}(::UndefInitializer, codomain::TensorSpace, domain::TensorSpace) where {T} =
100103
TensorMap{T}(undef, codomain domain)
101104
Tensor{T}(::UndefInitializer, V::TensorSpace) where {T} = TensorMap{T}(undef, V one(V))
@@ -108,7 +111,7 @@ Tensor{T}(::UndefInitializer, V::TensorSpace) where {T} = TensorMap{T}(undef, V
108111
Construct a `TensorMap` with uninitialized data stored as `A <: DenseVector{T}`.
109112
"""
110113
TensorMapWithStorage{T, A}(::UndefInitializer, V::TensorMapSpace) where {T, A} =
111-
TensorMap{T, spacetype(V), numout(V), numin(V), A}(undef, V)
114+
tensormaptype(spacetype(V), numout(V), numin(V), A)(undef, V)
112115
TensorMapWithStorage{T, A}(::UndefInitializer, codomain::TensorSpace, domain::TensorSpace) where {T, A} =
113116
TensorMapWithStorage{T, A}(undef, codomain domain)
114117
TensorWithStorage{T, A}(::UndefInitializer, V::TensorSpace) where {T, A} = TensorMapWithStorage{T, A}(undef, V one(V))
@@ -128,7 +131,7 @@ Construct a `TensorMap` from the given raw data.
128131
This constructor takes ownership of the provided vector, and will not make an independent copy.
129132
"""
130133
TensorMap{T}(data::DenseVector{T}, V::TensorMapSpace) where {T} =
131-
TensorMapWithStorage{T, typeof(data)}(data, V)
134+
tensormaptype(spacetype(V), numout(V), numin(V), typeof(data))(data, V)
132135
TensorMap{T}(data::DenseVector{T}, codomain::TensorSpace, domain::TensorSpace) where {T} =
133136
TensorMap{T}(data, codomain domain)
134137

@@ -141,8 +144,7 @@ Construct a `TensorMap` from the given raw data.
141144
This constructor takes ownership of the provided vector, and will not make an independent copy.
142145
"""
143146
function TensorMapWithStorage{T, A}(data::A, V::TensorMapSpace) where {T, A}
144-
length(data) == dim(V) || throw(DimensionMismatch("invalid length of data"))
145-
return TensorMap{T, spacetype(V), numout(V), numin(V), A}(data, V)
147+
return tensormaptype(spacetype(V), numout(V), numin(V), typeof(data))(data, V)
146148
end
147149
TensorMapWithStorage{T, A}(data::A, codomain::TensorSpace, domain::TensorSpace) where {T, A} =
148150
TensorMapWithStorage{T, A}(data, codomain domain)
@@ -213,11 +215,11 @@ function TensorMapWithStorage{T, A}(
213215
) where {T, A}
214216
# refer to specific raw data constructors if input is a vector of the correct length
215217
ndims(data) == 1 && length(data) == dim(V) &&
216-
return TensorMap{T, spacetype(V), numout(V), numin(V), A}(data, V)
218+
return tensormaptype(spacetype(V), numout(V), numin(V), A)(data, V)
217219

218220
# special case trivial: refer to same method, but now with vector argument
219221
sectortype(V) === Trivial &&
220-
return TensorMap{T, spacetype(V), numout(V), numin(V), A}(reshape(data, length(data)), V)
222+
return tensormaptype(spacetype(V), numout(V), numin(V), A)(reshape(data, length(data)), V)
221223

222224
# do projection
223225
t = TensorMapWithStorage{T, A}(undef, V)
@@ -230,7 +232,7 @@ function TensorMapWithStorage{T, A}(
230232
return t
231233
end
232234
TensorMapWithStorage{T, A}(data::AbstractArray, codom::TensorSpace, dom::TensorSpace; kwargs...) where {T, A} =
233-
TensorMapWithStorage(data, codom dom; kwargs...)
235+
TensorMapWithStorage{T, A}(data, codom dom; kwargs...)
234236
TensorWithStorage{T, A}(data::AbstractArray, codom::TensorSpace; kwargs...) where {T, A} =
235237
TensorMapWithStorage{T, A}(data, codom one(codom); kwargs...)
236238

src/tensors/tensoroperations.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,13 @@ function TO.tensorstructure(t::AbstractTensorMap, iA::Int, conjA::Bool)
66
end
77

88
function TO.tensoralloc(
9-
::Type{TT}, structure::TensorMapSpace{S, N₁, N₂},
10-
istemp::Val, allocator = TO.DefaultAllocator()
11-
) where {T, S, N₁, N₂, TT <: AbstractTensorMap{T, S, N₁, N₂}}
9+
::Type{TT}, structure::TensorMapSpace, istemp::Val, allocator = TO.DefaultAllocator()
10+
) where {TT <: AbstractTensorMap}
1211
A = storagetype(TT)
1312
dim = fusionblockstructure(structure).totaldim
1413
data = TO.tensoralloc(A, dim, istemp, allocator)
15-
# return TT(data, structure)
16-
return TensorMap{T}(data, structure)
14+
TT′ = tensormaptype(spacetype(structure), numout(structure), numin(structure), typeof(data))
15+
return TT′(data, structure)
1716
end
1817

1918
function TO.tensorfree!(t::TensorMap, allocator = TO.DefaultAllocator())

0 commit comments

Comments
 (0)