Skip to content

Commit 062081e

Browse files
committed
rework tensor contructors
try handle PtrArrays fix docstrings
1 parent 152ea71 commit 062081e

2 files changed

Lines changed: 101 additions & 83 deletions

File tree

docs/src/lib/tensors.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ Tensor
3434

3535
A `TensorMap` with undefined data can be constructed by specifying its domain and codomain:
3636
```@docs
37-
TensorMap{T}(::UndefInitializer, V::TensorMapSpace{S,N₁,N₂}) where {T,S,N₁,N₂}
37+
TensorMap{T}(::UndefInitializer, V::TensorMapSpace)
3838
```
3939

4040
The resulting object can then be filled with data using the `setindex!` method as discussed
@@ -45,8 +45,8 @@ in an `@tensor output[...] = ...` expression.
4545
Alternatively, a `TensorMap` can be constructed by specifying its data, codmain and domain
4646
in one of the following ways:
4747
```@docs
48-
TensorMap(data::AbstractDict{<:Sector,<:AbstractMatrix}, V::TensorMapSpace{S,N₁,N₂}) where {S,N₁,N₂}
49-
TensorMap(data::AbstractArray, V::TensorMapSpace{S,N₁,N₂}; tol) where {S<:IndexSpace,N₁,N₂}
48+
TensorMap(data::AbstractDict{<:Sector,<:AbstractMatrix}, V::TensorMapSpace)
49+
TensorMap(data::AbstractArray, V::TensorMapSpace; tol)
5050
```
5151

5252
Finally, we also support the following `Array`-like constructors

src/tensors/tensor.jl

Lines changed: 98 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -71,87 +71,109 @@ dim(t::TensorMap) = length(t.data)
7171

7272
# General TensorMap constructors
7373
#--------------------------------
74-
# undef constructors
74+
# hook for mapping input types to storage types -- to be implemented in extensions
75+
_tensormap_storagetype(::Type{A}) where {A <: AbstractArray} = _tensormap_storagetype(eltype(A))
76+
_tensormap_storagetype(::Type{A}) where {A <: DenseVector{<:Number}} = A
77+
_tensormap_storagetype(::Type{T}) where {T <: Number} = Vector{T}
78+
79+
# utility type alias that makes constructors also work for type aliases that specify
80+
# different storage types. (i.e. CuTensorMap = _TensorMap{T, CuVector{T}, ...})
81+
# TODO: do we need a name for this and do we want to consider this as public?
82+
const _TensorMap{T, A <: DenseVector{T}, S, N₁, N₂} = TensorMap{T, S, N₁, N₂, A}
83+
const _Tensor{T, A <: DenseVector{T}, S, N} = Tensor{T, S, N, A}
84+
85+
# undef constructors:
86+
# - dispatch start with TensorMap{T}
87+
# - select A and map to _TensorMap{T, A}
88+
# - select S, N1, N2 and map to TensorMap{T,S,N1,N2,A}
7589
"""
76-
TensorMap{T}(undef, codomain::ProductSpace{S,N₁}, domain::ProductSpace{S,N₂})
77-
where {T,S,N₁,N₂}
90+
TensorMap{T}(undef, codomain::ProductSpace{S, N₁}, domain::ProductSpace{S, N₂}) where {T, S, N₁, N₂}
7891
TensorMap{T}(undef, codomain ← domain)
7992
TensorMap{T}(undef, domain → codomain)
80-
# expert mode: select storage type `A`
81-
TensorMap{T,S,N₁,N₂,A}(undef, codomain ← domain)
82-
TensorMap{T,S,N₁,N₂,A}(undef, domain → domain)
8393
84-
Construct a `TensorMap` with uninitialized data.
94+
Construct a `TensorMap` with uninitialized data with elements of type `T`.
8595
"""
86-
function TensorMap{T}(::UndefInitializer, V::TensorMapSpace{S, N₁, N₂}) where {T, S, N₁, N₂}
87-
return TensorMap{T, S, N₁, N₂, Vector{T}}(undef, V)
88-
end
89-
function TensorMap{T}(
90-
::UndefInitializer, codomain::TensorSpace{S}, domain::TensorSpace{S}
91-
) where {T, S}
92-
return TensorMap{T}(undef, codomain domain)
93-
end
94-
function Tensor{T}(::UndefInitializer, V::TensorSpace{S}) where {T, S}
95-
return TensorMap{T}(undef, V one(V))
96-
end
96+
TensorMap{T}(::UndefInitializer, V::TensorMapSpace) where {T} =
97+
_TensorMap{T, _tensormap_storagetype(T)}(undef, V)
98+
TensorMap{T}(::UndefInitializer, codomain::TensorSpace, domain::TensorSpace) where {T} =
99+
TensorMap{T}(undef, codomain domain)
100+
Tensor{T}(::UndefInitializer, V::TensorSpace) where {T} = TensorMap{T}(undef, V one(V))
101+
102+
# specifying storagetype, fill in other parameters
103+
"""
104+
(TensorMap{T, S, N₁, N₂, A} where {S, N₁, N₂})(undef, codomain, domain) where {T, A}
105+
(TensorMap{T, S, N₁, N₂, A} where {S, N₁, N₂})(undef, codomain ← domain) where {T, A}
106+
(TensorMap{T, S, N₁, N₂, A} where {S, N₁, N₂})(undef, domain → codomain) where {T, A}
107+
108+
Construct a `TensorMap` with uninitialized data stored as `A <: DenseVector{T}`.
109+
"""
110+
_TensorMap{T, A}(::UndefInitializer, V::TensorMapSpace) where {T, A} =
111+
TensorMap{T, spacetype(V), numout(V), numin(V), A}(undef, V)
112+
_TensorMap{T, A}(::UndefInitializer, codomain::TensorSpace, domain::TensorSpace) where {T, A} =
113+
_TensorMap{T, A}(undef, codomain domain)
114+
_Tensor{T, A}(::UndefInitializer, V::TensorSpace) where {T, A} = _TensorMap{T, A}(undef, V one(V))
97115

98116
# constructor starting from vector = independent data (N₁ + N₂ = 1 is special cased below)
99117
# documentation is captured by the case where `data` is a general array
100-
# here, we force the `T` argument to distinguish it from the more general constructor below
101-
function TensorMap{T}(
102-
data::A, V::TensorMapSpace{S, N₁, N₂}
103-
) where {T, S, N₁, N₂, A <: DenseVector{T}}
104-
return TensorMap{T, S, N₁, N₂, A}(data, V)
105-
end
106-
function TensorMap{T}(
107-
data::DenseVector{T}, codomain::TensorSpace{S}, domain::TensorSpace{S}
108-
) where {T, S}
109-
return TensorMap(data, codomain domain)
110-
end
118+
# here, we force the `T` and/or `A` argument to distinguish it from the more general constructor below
119+
TensorMap{T}(data::DenseVector{T}, V::TensorMapSpace) where {T} =
120+
_TensorMap{T, typeof(data)}(data, V)
121+
TensorMap{T}(data::DenseVector{T}, codomain::TensorSpace, domain::TensorSpace) where {T} =
122+
TensorMap{T}(data, codomain domain)
123+
124+
_TensorMap{T, A}(data::DenseVector{T}, V::TensorMapSpace) where {T, A} =
125+
TensorMap{T, spacetype(V), numout(V), numin(V), A}(data, V)
126+
_TensorMap{T, A}(data::DenseVector{T}, codomain::TensorSpace, domain::TensorSpace) where {T, A} =
127+
_TensorMap{T, A}(data, codomain domain)
111128

112129
# constructor starting from block data
130+
const _BlockData{I <: Sector, A <: AbstractMatrix} = AbstractDict{I, A}
131+
113132
"""
114-
TensorMap(data::AbstractDict{<:Sector,<:AbstractMatrix}, codomain::ProductSpace{S,N₁},
115-
domain::ProductSpace{S,N₂}) where {S<:ElementarySpace,N₁,N₂}
133+
TensorMap(data::AbstractDict{<:Sector, <:AbstractMatrix}, codomain::ProductSpace, domain::ProductSpace)
116134
TensorMap(data, codomain ← domain)
117135
TensorMap(data, domain → codomain)
118136
119137
Construct a `TensorMap` by explicitly specifying its block data.
120138
121139
## Arguments
122-
- `data::AbstractDict{<:Sector,<:AbstractMatrix}`: dictionary containing the block data for
140+
- `data::AbstractDict{<:Sector, <:AbstractMatrix}`: dictionary containing the block data for
123141
each coupled sector `c` as a matrix of size `(blockdim(codomain, c), blockdim(domain, c))`.
124-
- `codomain::ProductSpace{S,N₁}`: the codomain as a `ProductSpace` of `N₁` spaces of type
125-
`S<:ElementarySpace`.
126-
- `domain::ProductSpace{S,N₂}`: the domain as a `ProductSpace` of `N₂` spaces of type
127-
`S<:ElementarySpace`.
142+
- `codomain::ProductSpace{S, N₁}`: the codomain as a `ProductSpace` of `N₁` spaces of type
143+
`S <: ElementarySpace`.
144+
- `domain::ProductSpace{S, N₂}`: the domain as a `ProductSpace` of `N₂` spaces of type
145+
`S <: ElementarySpace`.
128146
129147
Alternatively, the domain and codomain can be specified by passing a [`HomSpace`](@ref)
130148
using the syntax `codomain ← domain` or `domain → codomain`.
131149
"""
132-
function TensorMap(
133-
data::AbstractDict{<:Sector, <:AbstractMatrix}, V::TensorMapSpace{S, N₁, N₂}
134-
) where {S, N₁, N₂}
135-
T = eltype(valtype(data))
136-
t = TensorMap{T}(undef, V)
150+
function TensorMap(data::_BlockData, V::TensorMapSpace)
151+
A = _tensormap_storagetype(valtype(data))
152+
return _TensorMap{scalartype(A), A}(data, V)
153+
end
154+
TensorMap(data::_BlockData, codom::TensorSpace, dom::TensorSpace) =
155+
TensorMap(data, codom dom)
156+
157+
function _TensorMap{T, A}(data::_BlockData, V::TensorMapSpace) where {T, A}
158+
t = _TensorMap{T, A}(undef, V)
159+
160+
# check that there aren't too many blocks
161+
for (c, b) in data
162+
c blocksectors(t) || isempty(b) || throw(SectorMismatch("data for block sector $c not expected"))
163+
end
164+
165+
# fill in the blocks -- rely on conversion in copy
137166
for (c, b) in blocks(t)
138167
haskey(data, c) || throw(SectorMismatch("no data for block sector $c"))
139168
datac = data[c]
140-
size(datac) == size(b) ||
141-
throw(DimensionMismatch("wrong size of block for sector $c"))
169+
size(datac) == size(b) || throw(DimensionMismatch("wrong size of block for sector $c"))
142170
copy!(b, datac)
143171
end
144-
for (c, b) in data
145-
c blocksectors(t) || isempty(b) ||
146-
throw(SectorMismatch("data for block sector $c not expected"))
147-
end
172+
148173
return t
149174
end
150-
function TensorMap(
151-
data::AbstractDict{<:Sector, <:AbstractMatrix}, codom::TensorSpace{S}, dom::TensorSpace{S}
152-
) where {S}
153-
return TensorMap(data, codom dom)
154-
end
175+
_TensorMap{T, A}(data::_BlockData, codom::TensorSpace, dom::TensorSpace) where {T, A} =
176+
_TensorMap{T, A}(data, codom dom)
155177

156178
@doc """
157179
zeros([T=Float64,], codomain::ProductSpace{S,N₁}, domain::ProductSpace{S,N₂}) where {S,N₁,N₂,T}
@@ -317,49 +339,45 @@ cases.
317339
to a plain array is possible, and only in the case where the `data` actually respects
318340
the specified symmetry structure, up to a tolerance `tol`.
319341
"""
320-
function TensorMap(
321-
data::AbstractArray, V::TensorMapSpace{S, N₁, N₂};
322-
tol = sqrt(eps(real(float(eltype(data)))))
323-
) where {S <: IndexSpace, N₁, N₂}
342+
function TensorMap(data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data))))))
324343
T = eltype(data)
325-
if ndims(data) == 1 && length(data) == dim(V)
326-
if data isa DenseVector # refer to specific data-capturing constructor
327-
return TensorMap{T}(data, V)
328-
else
329-
return TensorMap{T}(collect(data), V)
330-
end
331-
end
344+
A = _tensormap_storagetype(typeof(data))
345+
return _TensorMap{T, A}(data, V; tol)
346+
end
347+
function _TensorMap{T, A}(
348+
data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data)))))
349+
) where {T, A}
350+
# refer to specific data-capturing constructors if input is a vector of the correct length
351+
ndims(data) == 1 && length(data) == dim(V) && return _TensorMap{T, A}(data, V)
332352

333353
# dimension check
334-
codom = codomain(V)
335-
dom = domain(V)
354+
codom, dom = codomain(V), domain(V)
336355
arraysize = dims(V)
337356
matsize = (dim(codom), dim(dom))
357+
(size(data) == arraysize || size(data) == matsize) || throw(DimensionMismatch())
338358

339-
if !(size(data) == arraysize || size(data) == matsize)
340-
throw(DimensionMismatch())
341-
end
342-
343-
if sectortype(S) === Trivial # refer to same method, but now with vector argument
344-
return TensorMap(reshape(data, length(data)), V)
359+
if sectortype(V) === Trivial # refer to same method, but now with vector argument
360+
return _TensorMap{T, A}(reshape(data, length(data)), V)
345361
end
346362

347-
t = TensorMap{T}(undef, codom, dom)
363+
t = _TensorMap{T, A}(undef, V)
348364
arraydata = reshape(collect(data), arraysize)
349365
t = project_symmetric!(t, arraydata)
350366
if !isapprox(arraydata, convert(Array, t); atol = tol)
351367
throw(ArgumentError("Data has non-zero elements at incompatible positions"))
352368
end
353369
return t
354370
end
355-
function TensorMap(
356-
data::AbstractArray, codom::TensorSpace{S}, dom::TensorSpace{S}; kwargs...
357-
) where {S}
358-
return TensorMap(data, codom dom; kwargs...)
359-
end
360-
function Tensor(data::AbstractArray, codom::TensorSpace, ; kwargs...)
361-
return TensorMap(data, codom one(codom); kwargs...)
362-
end
371+
372+
TensorMap(data::AbstractArray, codom::TensorSpace, dom::TensorSpace; kwargs...) =
373+
TensorMap(data, codom dom; kwargs...)
374+
_TensorMap{T, A}(data::AbstractArray, codom::TensorSpace, dom::TensorSpace; kwargs...) where {T, A} =
375+
_TensorMap(data, codom dom; kwargs...)
376+
377+
Tensor(data::AbstractArray, codom::TensorSpace; kwargs...) =
378+
TensorMap(data, codom one(codom); kwargs...)
379+
_Tensor{T, A}(data::AbstractArray, codom::TensorSpace; kwargs...) where {T, A} =
380+
_TensorMap{T, A}(data, codom one(codom); kwargs...)
363381

364382
"""
365383
project_symmetric!(t::TensorMap, data::DenseArray) -> TensorMap

0 commit comments

Comments
 (0)