Skip to content

Commit 76c3a0c

Browse files
committed
Try to finagle ones and zeros again
1 parent 51fea05 commit 76c3a0c

4 files changed

Lines changed: 20 additions & 13 deletions

File tree

ext/TensorKitCUDAExt/cutensormap.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,6 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one))
3737
fill!(t, $felt(T))
3838
return t
3939
end
40-
function Base.$fname(
41-
::Type{TA}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
42-
) where {TA <: CuArray, S <: IndexSpace}
43-
return CUDA.$fname(eltype(TA), codomain domain)
44-
end
4540
end
4641
end
4742

src/tensors/tensor.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -306,14 +306,9 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one))
306306
return Base.$fname(codomain domain)
307307
end
308308
function Base.$fname(
309-
::Type{T}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
310-
) where {T <: Number, S <: IndexSpace}
311-
return Base.$fname(T, codomain domain)
312-
end
313-
function Base.$fname(
314-
::Type{TA}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
315-
) where {TA <: Array, S <: IndexSpace}
316-
return Base.$fname(eltype(TA), codomain domain)
309+
::Type{TorA}, codomain::TensorSpace{S}, domain::TensorSpace{S} = one(codomain)
310+
) where {TorA, S <: IndexSpace}
311+
return Base.$fname(TorA, codomain domain)
317312
end
318313
Base.$fname(V::TensorMapSpace) = Base.$fname(Float64, V)
319314
function Base.$fname(::Type{TorA}, V::TensorMapSpace) where {TorA}

test/cuda/tensors.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ for V in spacelist
5555
@test domain(t) == one(W)
5656
@test typeof(t) == TensorMap{Float64, spacetype(t), 5, 0, CuVector{Float64, CUDA.DeviceMemory}}
5757
end
58+
for f in (Base.ones, Base.zeros)
59+
t = @constinferred f(CuVector{Float64}, W)
60+
@test scalartype(t) == Float64
61+
@test codomain(t) == W
62+
@test space(t) == (W one(W))
63+
@test domain(t) == one(W)
64+
@test typeof(t) == TensorMap{Float64, spacetype(t), 5, 0, CuVector{Float64, CUDA.DeviceMemory}}
65+
end
5866
for f in (rand, randn)
5967
t = @constinferred f(CuVector{Float64, CUDA.DeviceMemory}, W)
6068
@test scalartype(t) == Float64

test/tensors/tensors.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ for V in spacelist
4444
@test space(t) == (W one(W))
4545
@test domain(t) == one(W)
4646
@test typeof(t) == TensorMap{T, spacetype(t), 5, 0, Vector{T}}
47+
# Array type input
48+
t = @constinferred zeros(Vector{T}, W)
49+
@test @constinferred(hash(t)) == hash(deepcopy(t))
50+
@test scalartype(t) == T
51+
@test norm(t) == 0
52+
@test codomain(t) == W
53+
@test space(t) == (W one(W))
54+
@test domain(t) == one(W)
55+
@test typeof(t) == TensorMap{T, spacetype(t), 5, 0, Vector{T}}
4756
# blocks
4857
bs = @constinferred blocks(t)
4958
if !isempty(blocksectors(t)) # multifusion space ending on module gives empty data

0 commit comments

Comments
 (0)