Skip to content

Commit 6f92aeb

Browse files
authored
Introduce SectorVector (#48)
1 parent 3b7447b commit 6f92aeb

4 files changed

Lines changed: 74 additions & 15 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GradedArrays"
22
uuid = "bc96ca6e-b7c8-4bb6-888e-c93f838762c2"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.4.14"
4+
version = "0.4.15"
55

66
[deps]
77
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"

src/sectorunitrange.jl

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,25 @@
11
# This files defines SectorUnitRange, a unit range associated with a sector and an arrow
22

3+
struct SectorVector{T,Sector,Values<:AbstractVector{T}} <: AbstractVector{T}
4+
sector::Sector
5+
values::Values
6+
isdual::Bool
7+
end
8+
9+
sector(sv::SectorVector) = sv.sector
10+
ungrade(sv::SectorVector) = sv.values
11+
isdual(sv::SectorVector) = sv.isdual
12+
13+
function sectorvector(s, v::AbstractVector, b::Bool=false)
14+
return SectorVector(to_sector(s), v, b)
15+
end
16+
Base.length(sv::SectorVector) = length(sv.values)
17+
Base.size(sv::SectorVector) = (length(sv),)
18+
function Base.axes(sv::SectorVector)
19+
(sectorrange(sector(sv), only(axes(ungrade(sv))), isdual(sv)),)
20+
end
21+
Base.getindex(sv::SectorVector, i::Integer) = ungrade(sv)[i]
22+
323
# ===================================== Definition =======================================
424

525
# This implementation contains the "full range"
@@ -51,21 +71,34 @@ Base.iterate(sr::SectorUnitRange) = iterate(ungrade(sr))
5171
Base.iterate(sr::SectorUnitRange, i::Integer) = iterate(ungrade(sr), i)
5272

5373
Base.length(sr::SectorUnitRange) = length(ungrade(sr))
74+
Base.size(sr::SectorUnitRange) = (length(sr),)
75+
function Base.axes(sr::SectorUnitRange)
76+
(sectorrange(sector(sr), only(axes(ungrade(sr))), isdual(sr)),)
77+
end
5478

5579
Base.last(sr::SectorUnitRange) = last(ungrade(sr))
5680

5781
# slicing
5882
Base.getindex(sr::SectorUnitRange, i::Integer) = ungrade(sr)[i]
5983

60-
function Base.getindex(sr::SectorUnitRange, r::AbstractUnitRange{T}) where {T<:Integer}
61-
return sr[SymmetryStyle(sr), r]
84+
function Base.getindex(sr::SectorUnitRange, I::AbstractVector{<:Integer})
85+
return sr[SymmetryStyle(sr), I]
6286
end
63-
function Base.getindex(sr::SectorUnitRange, ::NotAbelianStyle, r::AbstractUnitRange)
87+
function Base.getindex(sr::SectorUnitRange, I::AbstractUnitRange{<:Integer})
88+
return sr[SymmetryStyle(sr), I]
89+
end
90+
function Base.getindex(sr::SectorUnitRange, ::NotAbelianStyle, r::AbstractVector{<:Integer})
6491
return ungrade(sr)[r]
6592
end
66-
function Base.getindex(sr::SectorUnitRange, ::AbelianStyle, r::AbstractUnitRange)
93+
function Base.getindex(sr::SectorUnitRange, ::AbelianStyle, r::AbstractUnitRange{<:Integer})
6794
return sectorrange(sector(sr), ungrade(sr)[ungrade(r)], isdual(sr))
6895
end
96+
function Base.getindex(sr::SectorUnitRange, ::AbelianStyle, r::AbstractVector{<:Integer})
97+
return sectorvector(sector(sr), ungrade(sr)[ungrade(r)], isdual(sr))
98+
end
99+
function Base.getindex(sr::SectorUnitRange, I::AbstractVector{Bool})
100+
return sr[to_indices(sr, (I,))...]
101+
end
69102

70103
# TODO replace (:,x) indexing with kronecker(:, x)
71104
Base.getindex(sr::SectorUnitRange, t::Tuple{Colon,<:Integer}) = sr[(:, last(t):last(t))]
@@ -114,3 +147,7 @@ sector_type(::Type{<:SectorUnitRange{T,Sector}}) where {T,Sector} = Sector
114147
# TBD error for non-integer?
115148
sector_multiplicity(sr::SectorUnitRange) = length(sr) ÷ length(sector(sr))
116149
sector_multiplicities(sr::SectorUnitRange) = [sector_multiplicity(sr)] # TBD remove?
150+
151+
function Base.similar(A::Type{<:AbstractArray}, ax::Tuple{SectorOneTo,Vararg{SectorOneTo}})
152+
return similar(A, ungrade.(ax))
153+
end

test/test_sectorunitrange.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using Test: @test, @test_throws, @testset
2-
31
using BlockArrays:
42
Block,
53
BlockBoundsError,
@@ -11,12 +9,12 @@ using BlockArrays:
119
blockisequal,
1210
blocks,
1311
findblock
14-
using TestExtras: @constinferred
15-
1612
using GradedArrays:
1713
U1,
1814
SU,
15+
SectorOneTo,
1916
SectorUnitRange,
17+
SectorVector,
2018
dual,
2119
flip,
2220
isdual,
@@ -29,6 +27,8 @@ using GradedArrays:
2927
sectors,
3028
space_isequal,
3129
ungrade
30+
using Test: @test, @test_throws, @testset
31+
using TestExtras: @constinferred
3232

3333
@testset "SectorUnitRange" begin
3434
sr = sectorrange(SU((1, 0)), 2)
@@ -49,7 +49,8 @@ using GradedArrays:
4949
@test eltype(sr) === Int
5050
@test step(sr) == 1
5151
@test eachindex(sr) == Base.oneto(6)
52-
@test only(axes(sr)) isa Base.OneTo
52+
@test only(axes(sr)) isa SectorOneTo
53+
@test sector(only(axes(sr))) == sector(sr)
5354
@test only(axes(sr)) == 1:6
5455
@test iterate(sr) == (1, 1)
5556
for i in 1:5
@@ -148,6 +149,11 @@ using GradedArrays:
148149
@test (@constinferred getindex(srab, 2:2)) isa SectorUnitRange
149150
@test space_isequal(srab[2:2], sectorrange(U1(1), 2:2))
150151
@test space_isequal(dual(srab)[2:2], sectorrange(U1(1), 2:2, true))
152+
@test srab[[1, 3]] isa SectorVector{Int}
153+
@test sector(srab[[1, 3]]) == sector(srab)
154+
@test ungrade(srab[[1, 3]]) == [1, 3]
155+
@test length(srab[[1, 3]]) == 2
156+
@test space_isequal(only(axes(srab[[1, 3]])), sectorrange(U1(1), 2))
151157

152158
# Slice sector range with sector range
153159
sr1 = sectorrange(U1(1), 4)

test/test_tensor_product.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,30 +18,46 @@ using TensorProducts: ⊗, OneToOne, tensor_product
1818
using Test: @test, @testset
1919
using TestExtras: @constinferred
2020

21-
GradedArrays.SymmetryStyle(::Type{<:String}) = NotAbelianStyle()
22-
GradedArrays.tensor_product(s1::String, s2::String) = gradedrange([s1 * s2 => 1])
21+
struct NotAbelianString
22+
str::String
23+
end
24+
Base.:*(s1::NotAbelianString, s2::NotAbelianString) = NotAbelianString(s1.str * s2.str)
25+
GradedArrays.SymmetryStyle(::Type{<:NotAbelianString}) = NotAbelianStyle()
26+
function GradedArrays.tensor_product(s1::NotAbelianString, s2::NotAbelianString)
27+
gradedrange([s1 * s2 => 1])
28+
end
29+
Base.length(s::NotAbelianString) = length(s.str)
2330

2431
@testset "unmerged_tensor_product" begin
2532
@test unmerged_tensor_product() isa OneToOne
2633
@test unmerged_tensor_product(OneToOne(), OneToOne()) isa OneToOne
2734
@test unmerged_tensor_product(1:1, 1:1) == 1:1
2835
@test sectormergesort(1:1) isa UnitRange
2936

30-
a = gradedrange(["x" => 2, "y" => 3])
37+
a = gradedrange([NotAbelianString("x") => 2, NotAbelianString("y") => 3])
3138
@test space_isequal(unmerged_tensor_product(a), a)
3239

3340
b = unmerged_tensor_product(a, a)
3441
@test b isa GradedOneTo
3542
@test length(b) == 50
3643
@test blocklength(b) == 4
3744
@test blocklengths(b) == [8, 12, 12, 18]
38-
@test space_isequal(b, gradedrange(["xx" => 4, "yx" => 6, "xy" => 6, "yy" => 9]))
45+
@test space_isequal(
46+
b,
47+
gradedrange([
48+
NotAbelianString("xx") => 4,
49+
NotAbelianString("yx") => 6,
50+
NotAbelianString("xy") => 6,
51+
NotAbelianString("yy") => 9,
52+
]),
53+
)
3954

4055
c = unmerged_tensor_product(a, a, a)
4156
@test c isa GradedOneTo
4257
@test length(c) == 375
4358
@test blocklength(c) == 8
44-
@test sectors(c) == ["xxx", "yxx", "xyx", "yyx", "xxy", "yxy", "xyy", "yyy"]
59+
@test sectors(c) ==
60+
NotAbelianString.(["xxx", "yxx", "xyx", "yyx", "xxy", "yxy", "xyy", "yyy"])
4561

4662
a = gradedrange([U1(1) => 1, U1(2) => 3, U1(1) => 1])
4763
@test space_isequal(

0 commit comments

Comments
 (0)