Skip to content

Commit 633f15e

Browse files
committed
uniformize into similarstoragetype
fix ambiguity more careful with storagetypes more careful with tensoroperations even more careful the carefulest!
1 parent c4fdb2d commit 633f15e

3 files changed

Lines changed: 67 additions & 30 deletions

File tree

src/tensors/abstracttensor.jl

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,57 @@ end
4545
Return the type of vector that stores the data of a tensor.
4646
""" storagetype
4747

48-
similarstoragetype(TT::Type{<:AbstractTensorMap}) = similarstoragetype(TT, scalartype(TT))
48+
# storage type determination and promotion - hooks for specializing
49+
# the default implementation tries to leverarge inference and `similar`
50+
@doc """
51+
similarstoragetype(t, [T = scalartype(t)]) -> Type{<:DenseVector{T}}
52+
similarstoragetype(TT, [T = scalartype(t)]) -> Type{<:DenseVector{T}}
53+
similarstoragetype(A, [T = scalartype(t)]) -> Type{<:DenseVector{T}}
54+
similarstoragetype(D, [T = scalartype(t)]) -> Type{<:DenseVector{T}}
55+
56+
similarstoragetype(T::Type{<:Number}) -> Vector{T}
57+
58+
For a given tensor `t`, tensor type `TT <: AbstractTensorMap`, array type `A <: AbstractArray`,
59+
or sector dictionary type `D <: AbstractDict{<:Sector, <:AbstractMatrix}`, compute an appropriate
60+
storage type for tensors. Optionally, a different scalar type `T` can be supplied as well.
61+
62+
This function determines the type of newly allocated `TensorMap`s throughout TensorKit.jl.
63+
It does so by leveraging type inference and calls to `Base.similar` for automatically determining
64+
appropriate storage types. Additionally this registers the default storage type when only a type
65+
`T <: Number` is provided, which is `Vector{T}`.
66+
""" similarstoragetype
67+
68+
# implement in type domain
69+
similarstoragetype(t) = similarstoragetype(typeof(t))
70+
similarstoragetype(t, ::Type{T}) where {T <: Number} = similarstoragetype(typeof(t), T)
71+
72+
# avoid infinite recursion
73+
similarstoragetype(X::Type) =
74+
throw(ArgumentError("Cannot determine a storagetype for tensor / array type `$X`"))
75+
similarstoragetype(X::Type, ::Type{T}) where {T <: Number} =
76+
throw(ArgumentError("Cannot determine a storagetype for tensor / array type `$X` and/or scalar type `$T`"))
77+
78+
# implement on tensors
79+
similarstoragetype(::Type{TT}) where {TT <: AbstractTensorMap} = similarstoragetype(storagetype(TT))
80+
similarstoragetype(::Type{TT}, ::Type{T}) where {TT <: AbstractTensorMap, T <: Number} =
81+
similarstoragetype(storagetype(TT), T)
82+
83+
# implement on arrays
84+
similarstoragetype(::Type{A}) where {A <: DenseVector{<:Number}} = A
85+
Base.@assume_effects :foldable similarstoragetype(::Type{A}) where {A <: AbstractArray{<:Number}} =
86+
Core.Compiler.return_type(similar, Tuple{A, Int})
87+
Base.@assume_effects :foldable similarstoragetype(::Type{A}, ::Type{T}) where {A <: AbstractArray, T <: Number} =
88+
Core.Compiler.return_type(similar, Tuple{A, Type{T}, Int})
89+
90+
# implement on sectordicts
91+
similarstoragetype(::Type{D}) where {D <: AbstractDict{<:Sector, <:AbstractMatrix}} =
92+
similarstoragetype(valtype(D))
93+
similarstoragetype(::Type{D}, ::Type{T}) where {D <: AbstractDict{<:Sector, <:AbstractMatrix}, T <: Number} =
94+
similarstoragetype(valtype(D), T)
95+
96+
# default storage type for numbers
97+
similarstoragetype(::Type{T}) where {T <: Number} = Vector{T}
4998

50-
function similarstoragetype(TT::Type{<:AbstractTensorMap}, ::Type{T}) where {T}
51-
return Core.Compiler.return_type(similar, Tuple{storagetype(TT), Type{T}})
52-
end
5399

54100
# tensor characteristics: space and index information
55101
#-----------------------------------------------------
@@ -175,7 +221,6 @@ end
175221
InnerProductStyle(t::AbstractTensorMap) = InnerProductStyle(typeof(t))
176222
storagetype(t::AbstractTensorMap) = storagetype(typeof(t))
177223
blocktype(t::AbstractTensorMap) = blocktype(typeof(t))
178-
similarstoragetype(t::AbstractTensorMap, T = scalartype(t)) = similarstoragetype(typeof(t), T)
179224

180225
numout(t::AbstractTensorMap) = numout(typeof(t))
181226
numin(t::AbstractTensorMap) = numin(typeof(t))
@@ -503,17 +548,17 @@ end
503548

504549
# 3 arguments
505550
Base.similar(t::AbstractTensorMap, codomain::TensorSpace, domain::TensorSpace) =
506-
similar(t, similarstoragetype(t), codomain domain)
551+
similar(t, similarstoragetype(t, scalartype(t)), codomain domain)
507552
Base.similar(t::AbstractTensorMap, ::Type{T}, codomain::TensorSpace) where {T} =
508553
similar(t, T, codomain one(codomain))
509554

510555
# 2 arguments
511556
Base.similar(t::AbstractTensorMap, codomain::TensorSpace) =
512-
similar(t, similarstoragetype(t), codomain one(codomain))
513-
Base.similar(t::AbstractTensorMap, V::TensorMapSpace) = similar(t, similarstoragetype(t), V)
557+
similar(t, codomain one(codomain))
558+
Base.similar(t::AbstractTensorMap, V::TensorMapSpace) = similar(t, scalartype(t), V)
514559
Base.similar(t::AbstractTensorMap, ::Type{T}) where {T} = similar(t, T, space(t))
515560
# 1 argument
516-
Base.similar(t::AbstractTensorMap) = similar(t, similarstoragetype(t), space(t))
561+
Base.similar(t::AbstractTensorMap) = similar(t, scalartype(t), space(t))
517562

518563
# generic implementation for AbstractTensorMap -> returns `TensorMap`
519564
function Base.similar(t::AbstractTensorMap, ::Type{TorA}, V::TensorMapSpace) where {TorA}
@@ -524,7 +569,7 @@ end
524569

525570
# implementation in type-domain
526571
function Base.similar(::Type{TT}, V::TensorMapSpace) where {TT <: AbstractTensorMap}
527-
TT′ = tensormaptype(spacetype(V), numout(V), numin(V), similarstoragetype(TT))
572+
TT′ = tensormaptype(spacetype(V), numout(V), numin(V), similarstoragetype(TT, scalartype(TT)))
528573
return TT′(undef, V)
529574
end
530575
Base.similar(::Type{TT}, cod::TensorSpace, dom::TensorSpace) where {TT <: AbstractTensorMap} =

src/tensors/tensor.jl

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,11 @@ i.e. a tensor map with only a non-trivial output space.
4848
"""
4949
const Tensor{T, S, N, A} = TensorMap{T, S, N, 0, A}
5050

51-
function tensormaptype(S::Type{<:IndexSpace}, N₁, N₂, TorA::Type)
52-
A = _tensormap_storagetype(TorA)
53-
A <: DenseVector || throw(ArgumentError("Cannot determine a valid storage type from argument $TorA"))
51+
function tensormaptype(::Type{S}, N₁, N₂, ::Type{TorA}) where {S <: IndexSpace, TorA}
52+
A = similarstoragetype(TorA)
5453
return TensorMap{scalartype(A), S, N₁, N₂, A}
5554
end
5655

57-
# hook for mapping input types to storage types -- to be implemented in extensions
58-
_tensormap_storagetype(::Type{A}) where {A <: DenseVector{<:Number}} = A
59-
_tensormap_storagetype(::Type{A}) where {A <: Array} = _tensormap_storagetype(scalartype(A))
60-
_tensormap_storagetype(::Type{T}) where {T <: Number} = Vector{T}
61-
function _tensormap_storagetype(::Type{A}) where {A <: AbstractArray}
62-
PA = parenttype(A)
63-
PA === A && throw(MethodError(_tensormap_storagetype, A)) # avoid infinite recursion
64-
return _tensormap_storagetype(PA)
65-
end
66-
6756
# Basic methods for characterising a tensor:
6857
#--------------------------------------------
6958
space(t::TensorMap) = t.space
@@ -201,9 +190,8 @@ cases.
201190
the specified symmetry structure, up to a tolerance `tol`.
202191
"""
203192
function TensorMap(data::AbstractArray, V::TensorMapSpace; tol = sqrt(eps(real(float(eltype(data))))))
204-
T = eltype(data)
205-
A = _tensormap_storagetype(typeof(data))
206-
return TensorMapWithStorage{T, A}(data, V; tol)
193+
A = similarstoragetype(data)
194+
return TensorMapWithStorage{scalartype(A), A}(data, V; tol)
207195
end
208196
TensorMap(data::AbstractArray, codom::TensorSpace, dom::TensorSpace; kwargs...) =
209197
TensorMap(data, codom dom; kwargs...)
@@ -259,7 +247,7 @@ Construct a `TensorMap` by explicitly specifying its block data.
259247
- `domain::ProductSpace{S, N₂}`: the domain as a `ProductSpace` of `N₂` spaces of type `S <: ElementarySpace`.
260248
"""
261249
function TensorMap(data::_BlockData, V::TensorMapSpace)
262-
A = _tensormap_storagetype(valtype(data))
250+
A = similarstoragetype(data)
263251
return TensorMapWithStorage{scalartype(A), A}(data, V)
264252
end
265253
TensorMap(data::_BlockData, codom::TensorSpace, dom::TensorSpace) =

src/tensors/tensoroperations.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,16 @@ function TO.tensorcontract_type(
155155
) where {N₁, N₂}
156156
spacetype(A) == spacetype(B) || throw(SpaceMismatch("incompatible space types"))
157157
I = sectortype(A)
158-
M = similarstoragetype(A, sectorscalartype(I) <: Real ? TC : complex(TC))
159-
MB = similarstoragetype(B, sectorscalartype(I) <: Real ? TC : complex(TC))
160-
M == MB || throw(ArgumentError("incompatible storage types:\n$(M)$(MB)"))
158+
TC′ = isreal(I) ? TC : complex(TC)
159+
M = promote_storagetype(similarstoragetype(A, TC′), similarstoragetype(B, TC′))
161160
return tensormaptype(spacetype(A), N₁, N₂, M)
162161
end
163162

163+
# TODO: handle actual promotion rule system
164+
function promote_storagetype(::Type{M₁}, ::Type{M₂}) where {M₁, M₂}
165+
return M₁ === M₂ ? M₁ : throw(ArgumentError("Cannot determine storage type for combining `$M₁` and `$M₂`"))
166+
end
167+
164168
function TO.tensorcontract_structure(
165169
A::AbstractTensorMap, pA::Index2Tuple, conjA::Bool,
166170
B::AbstractTensorMap, pB::Index2Tuple, conjB::Bool,

0 commit comments

Comments
 (0)