@@ -71,87 +71,109 @@ dim(t::TensorMap) = length(t.data)
7171
7272# General TensorMap constructors
7373# --------------------------------
74- # undef constructors
74+ # hook for mapping input types to storage types -- to be implemented in extensions
75+ _tensormap_storagetype (:: Type{A} ) where {A <: AbstractArray } = _tensormap_storagetype (eltype (A))
76+ _tensormap_storagetype (:: Type{A} ) where {A <: DenseVector{<:Number} } = A
77+ _tensormap_storagetype (:: Type{T} ) where {T <: Number } = Vector{T}
78+
79+ # utility type alias that makes constructors also work for type aliases that specify
80+ # different storage types. (i.e. CuTensorMap = _TensorMap{T, CuVector{T}, ...})
81+ # TODO : do we need a name for this and do we want to consider this as public?
82+ const _TensorMap{T, A <: DenseVector{T} , S, N₁, N₂} = TensorMap{T, S, N₁, N₂, A}
83+ const _Tensor{T, A <: DenseVector{T} , S, N} = Tensor{T, S, N, A}
84+
85+ # undef constructors:
86+ # - dispatch start with TensorMap{T}
87+ # - select A and map to _TensorMap{T, A}
88+ # - select S, N1, N2 and map to TensorMap{T,S,N1,N2,A}
7589"""
76- TensorMap{T}(undef, codomain::ProductSpace{S,N₁}, domain::ProductSpace{S,N₂})
77- where {T,S,N₁,N₂}
90+ TensorMap{T}(undef, codomain::ProductSpace{S, N₁}, domain::ProductSpace{S, N₂}) where {T, S, N₁, N₂}
7891 TensorMap{T}(undef, codomain ← domain)
7992 TensorMap{T}(undef, domain → codomain)
80- # expert mode: select storage type `A`
81- TensorMap{T,S,N₁,N₂,A}(undef, codomain ← domain)
82- TensorMap{T,S,N₁,N₂,A}(undef, domain → domain)
8393
84- Construct a `TensorMap` with uninitialized data.
94+ Construct a `TensorMap` with uninitialized data with elements of type `T` .
8595"""
86- function TensorMap {T} (:: UndefInitializer , V:: TensorMapSpace{S, N₁, N₂} ) where {T, S, N₁, N₂}
87- return TensorMap {T, S, N₁, N₂, Vector{T}} (undef, V)
88- end
89- function TensorMap {T} (
90- :: UndefInitializer , codomain:: TensorSpace{S} , domain:: TensorSpace{S}
91- ) where {T, S}
92- return TensorMap {T} (undef, codomain ← domain)
93- end
94- function Tensor {T} (:: UndefInitializer , V:: TensorSpace{S} ) where {T, S}
95- return TensorMap {T} (undef, V ← one (V))
96- end
96+ TensorMap {T} (:: UndefInitializer , V:: TensorMapSpace ) where {T} =
97+ _TensorMap {T, _tensormap_storagetype(T)} (undef, V)
98+ TensorMap {T} (:: UndefInitializer , codomain:: TensorSpace , domain:: TensorSpace ) where {T} =
99+ TensorMap {T} (undef, codomain ← domain)
100+ Tensor {T} (:: UndefInitializer , V:: TensorSpace ) where {T} = TensorMap {T} (undef, V ← one (V))
101+
102+ # specifying storagetype, fill in other parameters
103+ """
104+ (TensorMap{T, S, N₁, N₂, A} where {S, N₁, N₂})(undef, codomain, domain) where {T, A}
105+ (TensorMap{T, S, N₁, N₂, A} where {S, N₁, N₂})(undef, codomain ← domain) where {T, A}
106+ (TensorMap{T, S, N₁, N₂, A} where {S, N₁, N₂})(undef, domain → codomain) where {T, A}
107+
108+ Construct a `TensorMap` with uninitialized data stored as `A <: DenseVector{T}`.
109+ """
110+ _TensorMap {T, A} (:: UndefInitializer , V:: TensorMapSpace ) where {T, A} =
111+ TensorMap {T, spacetype(V), numout(V), numin(V), A} (undef, V)
112+ _TensorMap {T, A} (:: UndefInitializer , codomain:: TensorSpace , domain:: TensorSpace ) where {T, A} =
113+ _TensorMap {T, A} (undef, codomain ← domain)
114+ _Tensor {T, A} (:: UndefInitializer , V:: TensorSpace ) where {T, A} = _TensorMap {T, A} (undef, V ← one (V))
97115
98116# constructor starting from vector = independent data (N₁ + N₂ = 1 is special cased below)
99117# documentation is captured by the case where `data` is a general array
100- # here, we force the `T` argument to distinguish it from the more general constructor below
101- function TensorMap {T} (
102- data:: A , V:: TensorMapSpace{S, N₁, N₂}
103- ) where {T, S, N₁, N₂, A <: DenseVector{T} }
104- return TensorMap {T, S, N₁, N₂, A} (data, V)
105- end
106- function TensorMap {T} (
107- data:: DenseVector{T} , codomain:: TensorSpace{S} , domain:: TensorSpace{S}
108- ) where {T, S}
109- return TensorMap (data, codomain ← domain)
110- end
118+ # here, we force the `T` and/or `A` argument to distinguish it from the more general constructor below
119+ TensorMap {T} (data:: DenseVector{T} , V:: TensorMapSpace ) where {T} =
120+ _TensorMap {T, typeof(data)} (data, V)
121+ TensorMap {T} (data:: DenseVector{T} , codomain:: TensorSpace , domain:: TensorSpace ) where {T} =
122+ TensorMap {T} (data, codomain ← domain)
123+
124+ _TensorMap {T, A} (data:: DenseVector{T} , V:: TensorMapSpace ) where {T, A} =
125+ TensorMap {T, spacetype(V), numout(V), numin(V), A} (data, V)
126+ _TensorMap {T, A} (data:: DenseVector{T} , codomain:: TensorSpace , domain:: TensorSpace ) where {T, A} =
127+ _TensorMap {T, A} (data, codomain ← domain)
111128
112129# constructor starting from block data
130+ const _BlockData{I <: Sector , A <: AbstractMatrix } = AbstractDict{I, A}
131+
113132"""
114- TensorMap(data::AbstractDict{<:Sector,<:AbstractMatrix}, codomain::ProductSpace{S,N₁},
115- domain::ProductSpace{S,N₂}) where {S<:ElementarySpace,N₁,N₂}
133+ TensorMap(data::AbstractDict{<:Sector, <:AbstractMatrix}, codomain::ProductSpace, domain::ProductSpace)
116134 TensorMap(data, codomain ← domain)
117135 TensorMap(data, domain → codomain)
118136
119137Construct a `TensorMap` by explicitly specifying its block data.
120138
121139## Arguments
122- - `data::AbstractDict{<:Sector,<:AbstractMatrix}`: dictionary containing the block data for
140+ - `data::AbstractDict{<:Sector, <:AbstractMatrix}`: dictionary containing the block data for
123141 each coupled sector `c` as a matrix of size `(blockdim(codomain, c), blockdim(domain, c))`.
124- - `codomain::ProductSpace{S,N₁}`: the codomain as a `ProductSpace` of `N₁` spaces of type
125- `S<: ElementarySpace`.
126- - `domain::ProductSpace{S,N₂}`: the domain as a `ProductSpace` of `N₂` spaces of type
127- `S<: ElementarySpace`.
142+ - `codomain::ProductSpace{S, N₁}`: the codomain as a `ProductSpace` of `N₁` spaces of type
143+ `S <: ElementarySpace`.
144+ - `domain::ProductSpace{S, N₂}`: the domain as a `ProductSpace` of `N₂` spaces of type
145+ `S <: ElementarySpace`.
128146
129147Alternatively, the domain and codomain can be specified by passing a [`HomSpace`](@ref)
130148using the syntax `codomain ← domain` or `domain → codomain`.
131149"""
132- function TensorMap (
133- data:: AbstractDict{<:Sector, <:AbstractMatrix} , V:: TensorMapSpace{S, N₁, N₂}
134- ) where {S, N₁, N₂}
135- T = eltype (valtype (data))
136- t = TensorMap {T} (undef, V)
150+ function TensorMap (data:: _BlockData , V:: TensorMapSpace )
151+ A = _tensormap_storagetype (valtype (data))
152+ return _TensorMap {scalartype(A), A} (data, V)
153+ end
154+ TensorMap (data:: _BlockData , codom:: TensorSpace , dom:: TensorSpace ) =
155+ TensorMap (data, codom ← dom)
156+
157+ function _TensorMap {T, A} (data:: _BlockData , V:: TensorMapSpace ) where {T, A}
158+ t = _TensorMap {T, A} (undef, V)
159+
160+ # check that there aren't too many blocks
161+ for (c, b) in data
162+ c ∈ blocksectors (t) || isempty (b) || throw (SectorMismatch (" data for block sector $c not expected" ))
163+ end
164+
165+ # fill in the blocks -- rely on conversion in copy
137166 for (c, b) in blocks (t)
138167 haskey (data, c) || throw (SectorMismatch (" no data for block sector $c " ))
139168 datac = data[c]
140- size (datac) == size (b) ||
141- throw (DimensionMismatch (" wrong size of block for sector $c " ))
169+ size (datac) == size (b) || throw (DimensionMismatch (" wrong size of block for sector $c " ))
142170 copy! (b, datac)
143171 end
144- for (c, b) in data
145- c ∈ blocksectors (t) || isempty (b) ||
146- throw (SectorMismatch (" data for block sector $c not expected" ))
147- end
172+
148173 return t
149174end
150- function TensorMap (
151- data:: AbstractDict{<:Sector, <:AbstractMatrix} , codom:: TensorSpace{S} , dom:: TensorSpace{S}
152- ) where {S}
153- return TensorMap (data, codom ← dom)
154- end
175+ _TensorMap {T, A} (data:: _BlockData , codom:: TensorSpace , dom:: TensorSpace ) where {T, A} =
176+ _TensorMap {T, A} (data, codom ← dom)
155177
156178@doc """
157179 zeros([T=Float64,], codomain::ProductSpace{S,N₁}, domain::ProductSpace{S,N₂}) where {S,N₁,N₂,T}
@@ -317,49 +339,45 @@ cases.
317339 to a plain array is possible, and only in the case where the `data` actually respects
318340 the specified symmetry structure, up to a tolerance `tol`.
319341"""
320- function TensorMap (
321- data:: AbstractArray , V:: TensorMapSpace{S, N₁, N₂} ;
322- tol = sqrt (eps (real (float (eltype (data)))))
323- ) where {S <: IndexSpace , N₁, N₂}
342+ function TensorMap (data:: AbstractArray , V:: TensorMapSpace ; tol = sqrt (eps (real (float (eltype (data))))))
324343 T = eltype (data)
325- if ndims (data) == 1 && length (data) == dim (V)
326- if data isa DenseVector # refer to specific data-capturing constructor
327- return TensorMap {T} (data, V)
328- else
329- return TensorMap {T} (collect (data), V)
330- end
331- end
344+ A = _tensormap_storagetype (typeof (data))
345+ return _TensorMap {T, A} (data, V; tol)
346+ end
347+ function _TensorMap {T, A} (
348+ data:: AbstractArray , V:: TensorMapSpace ; tol = sqrt (eps (real (float (eltype (data)))))
349+ ) where {T, A}
350+ # refer to specific data-capturing constructors if input is a vector of the correct length
351+ ndims (data) == 1 && length (data) == dim (V) && return _TensorMap {T, A} (data, V)
332352
333353 # dimension check
334- codom = codomain (V)
335- dom = domain (V)
354+ codom, dom = codomain (V), domain (V)
336355 arraysize = dims (V)
337356 matsize = (dim (codom), dim (dom))
357+ (size (data) == arraysize || size (data) == matsize) || throw (DimensionMismatch ())
338358
339- if ! (size (data) == arraysize || size (data) == matsize)
340- throw (DimensionMismatch ())
341- end
342-
343- if sectortype (S) === Trivial # refer to same method, but now with vector argument
344- return TensorMap (reshape (data, length (data)), V)
359+ if sectortype (V) === Trivial # refer to same method, but now with vector argument
360+ return _TensorMap {T, A} (reshape (data, length (data)), V)
345361 end
346362
347- t = TensorMap {T } (undef, codom, dom )
363+ t = _TensorMap {T, A } (undef, V )
348364 arraydata = reshape (collect (data), arraysize)
349365 t = project_symmetric! (t, arraydata)
350366 if ! isapprox (arraydata, convert (Array, t); atol = tol)
351367 throw (ArgumentError (" Data has non-zero elements at incompatible positions" ))
352368 end
353369 return t
354370end
355- function TensorMap (
356- data:: AbstractArray , codom:: TensorSpace{S} , dom:: TensorSpace{S} ; kwargs...
357- ) where {S}
358- return TensorMap (data, codom ← dom; kwargs... )
359- end
360- function Tensor (data:: AbstractArray , codom:: TensorSpace , ; kwargs... )
361- return TensorMap (data, codom ← one (codom); kwargs... )
362- end
371+
372+ TensorMap (data:: AbstractArray , codom:: TensorSpace , dom:: TensorSpace ; kwargs... ) =
373+ TensorMap (data, codom ← dom; kwargs... )
374+ _TensorMap {T, A} (data:: AbstractArray , codom:: TensorSpace , dom:: TensorSpace ; kwargs... ) where {T, A} =
375+ _TensorMap (data, codom ← dom; kwargs... )
376+
377+ Tensor (data:: AbstractArray , codom:: TensorSpace ; kwargs... ) =
378+ TensorMap (data, codom ← one (codom); kwargs... )
379+ _Tensor {T, A} (data:: AbstractArray , codom:: TensorSpace ; kwargs... ) where {T, A} =
380+ _TensorMap {T, A} (data, codom ← one (codom); kwargs... )
363381
364382"""
365383 project_symmetric!(t::TensorMap, data::DenseArray) -> TensorMap
0 commit comments