Skip to content

Commit 763ede0

Browse files
committed
More fixes
1 parent a268464 commit 763ede0

File tree

3 files changed

+14
-40
lines changed

3 files changed

+14
-40
lines changed

Project.toml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,12 @@ TensorKitMooncakeExt = "Mooncake"
3636
[workspace]
3737
projects = ["test"]
3838

39-
[extras]
40-
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
41-
4239
[compat]
4340
Adapt = "4"
4441
CUDA = "5.9"
4542
ChainRulesCore = "1"
4643
Dictionaries = "0.4"
4744
FiniteDifferences = "0.12"
48-
GPUArrays = "<11.5.0"
4945
LRUCache = "1.0.2"
5046
LinearAlgebra = "1"
5147
MatrixAlgebraKit = "0.6.5"

src/tensors/braidingtensor.jl

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct BraidingTensor{T, S, A} <: AbstractTensorMap{T, S, 2, 2}
1717
V1::S
1818
V2::S
1919
adjoint::Bool
20-
function BraidingTensor{T, S, A}(V1::S, V2::S, ::Type{A}, adjoint::Bool = false) where {T, S <: IndexSpace, A <: DenseVector{T}}
20+
function BraidingTensor{T, S, A}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace, A <: DenseVector{T}}
2121
for a in sectors(V1), b in sectors(V2), c in (a b)
2222
Nsymbol(a, b, c) == Nsymbol(b, a, c) ||
2323
throw(ArgumentError("Cannot define a braiding between $a and $b"))
@@ -26,45 +26,27 @@ struct BraidingTensor{T, S, A} <: AbstractTensorMap{T, S, 2, 2}
2626
# partial construction: only construct rowr and colr when needed
2727
end
2828
end
29-
function BraidingTensor{T, S}(V1::S, V2::S, ::Type{A}, adjoint::Bool = false) where {T, S <: IndexSpace, A}
30-
return BraidingTensor{T, S, A}(V1, V2, A, adjoint)
31-
end
32-
function BraidingTensor{T}(V1::S, V2::S, A, adjoint::Bool = false) where {T, S <: IndexSpace}
33-
return BraidingTensor{T, S}(V1, V2, A, adjoint)
34-
end
3529
function BraidingTensor{T}(V1::S, V2::S, adjoint::Bool = false) where {T, S <: IndexSpace}
36-
return BraidingTensor{T, S}(V1, V2, Vector{T}, adjoint)
37-
end
38-
function BraidingTensor{T}(V1::IndexSpace, V2::IndexSpace, A, adjoint::Bool = false) where {T}
39-
return BraidingTensor{T}(promote(V1, V2)..., A, adjoint)
30+
return BraidingTensor{T, S, Vector{T}}(V1, V2, adjoint)
4031
end
41-
function BraidingTensor{T}(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false) where {T}
42-
return BraidingTensor{T}(V1, V2, Vector{T}, adjoint)
43-
end
44-
function BraidingTensor(V1::IndexSpace, V2::IndexSpace, ::Type{A}, adjoint::Bool = false) where {T, A <: DenseVector{T}}
45-
return BraidingTensor{T}(promote(V1, V2)..., A, adjoint)
32+
function BraidingTensor(V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace}
33+
T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64
34+
return BraidingTensor{T, S, Vector{T}}(V1, V2, Vector{T}, adjoint)
4635
end
47-
function BraidingTensor(V1::IndexSpace, V2::IndexSpace, ::Type{T}, adjoint::Bool = false) where {T}
48-
return BraidingTensor{T}(promote(V1, V2)..., Vector{T}, adjoint)
36+
# necessary due to HomSpace ctor below
37+
function BraidingTensor(V1::S, V2::S, A, adjoint::Bool = false) where {S <: IndexSpace}
38+
T = eltype(A)
39+
return BraidingTensor{T, S, A}(V1, V2, adjoint)
4940
end
5041
function BraidingTensor(V1::IndexSpace, V2::IndexSpace, adjoint::Bool = false)
5142
return BraidingTensor(promote(V1, V2)..., adjoint)
5243
end
53-
function BraidingTensor(V1::S, V2::S, adjoint::Bool = false) where {S <: IndexSpace}
54-
T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64
55-
return BraidingTensor{T, S}(V1, V2, Vector{T}, adjoint)
56-
end
57-
function BraidingTensor(V1::S, V2::S, ::Type{A}, adjoint::Bool = false) where {S <: IndexSpace, A <: AbstractArray}
58-
T = BraidingStyle(sectortype(S)) isa SymmetricBraiding ? Float64 : ComplexF64
59-
A′ = similarstoragetype(A, T)
60-
return BraidingTensor{T, S}(V1, V2, A′, adjoint)
61-
end
6244
function BraidingTensor(V::HomSpace, adjoint::Bool = false)
6345
domain(V) == reverse(codomain(V)) ||
6446
throw(SpaceMismatch("Cannot define a braiding on $V"))
6547
return BraidingTensor(V[2], V[1], adjoint)
6648
end
67-
function BraidingTensor(V::HomSpace, ::Type{A}, adjoint::Bool = false) where {A}
49+
function BraidingTensor(V::HomSpace, A, adjoint::Bool = false)
6850
domain(V) == reverse(codomain(V)) ||
6951
throw(SpaceMismatch("Cannot define a braiding on $V"))
7052
return BraidingTensor(V[2], V[1], A, adjoint)
@@ -75,7 +57,7 @@ function BraidingTensor{T}(V::HomSpace, adjoint::Bool = false) where {T}
7557
return BraidingTensor{T}(V[2], V[1], adjoint)
7658
end
7759
function Base.adjoint(b::BraidingTensor{T, S, A}) where {T, S, A}
78-
return BraidingTensor{T, S, A}(b.V1, b.V2, A, !b.adjoint)
60+
return BraidingTensor{T, S, A}(b.V1, b.V2, !b.adjoint)
7961
end
8062

8163
storagetype(::Type{BraidingTensor{T, S, A}}) where {T, S, A} = A
@@ -113,7 +95,7 @@ end
11395

11496
function _set_subblock!(data, val)
11597
f(I) = ((I[1] == I[4]) & (I[2] == I[3])) * val
116-
data .= f.(CartesianIndices(data))
98+
@. data = f(CartesianIndices(data))
11799
end
118100

119101

test/cuda/planar.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@ for V in spacelist
2727
t3 = @testinferred adapt(storagetype(t2), t1)
2828
@test storagetype(t3) == storagetype(t2)
2929
# allowscalar needed for the StridedView comparison
30-
CUDA.@allowscalar begin
31-
@test t3 == t2
32-
end
30+
@test t3 t2
3331

3432
W2 = reverse(codomain(W)) domain(W)
3533
@test_throws SpaceMismatch BraidingTensor(W2)
@@ -41,9 +39,7 @@ for V in spacelist
4139
t3 = @inferred TensorMap(t2)
4240
@test storagetype(t3) == CuVector{ComplexF64, CUDA.DeviceMemory}
4341
t4 = braid(adapt(CuArray, id(scalartype(t2), domain(t2))), ((2, 1), (3, 4)), (1, 2, 3, 4))
44-
CUDA.@allowscalar begin
45-
@test t1 t4
46-
end
42+
@test t1 t4
4743
for (c, b) in blocks(t1)
4844
@test block(t1, c) b block(t3, c)
4945
end

0 commit comments

Comments
 (0)