From 1f5d851cb77fd9f794b34596a0ccad7de56e185e Mon Sep 17 00:00:00 2001 From: mtfishman Date: Wed, 25 Jun 2025 12:41:31 -0400 Subject: [PATCH 1/2] Preserve blocking of indices when slicing a unit range with a blocked unit range --- Project.toml | 2 +- src/blockaxis.jl | 20 ++++++++++++++++++++ test/test_blockindices.jl | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index edf66627..e455ea42 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "BlockArrays" uuid = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" -version = "1.6.3" +version = "1.6.4" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/blockaxis.jl b/src/blockaxis.jl index 57e14ce8..d8333202 100644 --- a/src/blockaxis.jl +++ b/src/blockaxis.jl @@ -105,6 +105,16 @@ function Base.AbstractUnitRange{T}(r::BlockedUnitRange) where {T} return _BlockedUnitRange(convert(T,first(r)), convert.(T,blocklasts(r))) end +# See: https://github.com/JuliaLang/julia/blob/b06d26075bf7b3f4e7f1b64b120f5665d8ed76f9/base/range.jl#L991-L1004 +function Base.getindex(r::AbstractUnitRange, s::AbstractBlockedUnitRange{T}) where {T<:Integer} + @boundscheck checkbounds(r, s) + + f = first(r) + start = oftype(f, f + first(s) - firstindex(r)) + lens = map(Base.Fix1(oftype, f), blocklengths(s)) + return blockedrange(start, lens) +end + """ BlockedOneTo{T, <:Union{AbstractVector{T}, NTuple{<:Any,T}}} where {T} @@ -164,6 +174,16 @@ function Base.AbstractUnitRange{T}(r::BlockedOneTo) where {T} return BlockedOneTo(convert.(T,blocklasts(r))) end +# See: https://github.com/JuliaLang/julia/blob/b06d26075bf7b3f4e7f1b64b120f5665d8ed76f9/base/range.jl#L1006-L1010 +function getindex(r::Base.OneTo{T}, s::BlockedOneTo) where T + @inline + @boundscheck checkbounds(r, s) + return BlockedOneTo(T.(blocklasts(s))) +end +function getindex(r::BlockedOneTo{T}, s::BlockedOneTo) where T + return Base.OneTo(r)[s] +end + """ blockedrange(blocklengths::Union{Tuple, AbstractVector}) blockedrange(first::Integer, blocklengths::Union{Tuple, AbstractVector}) diff --git a/test/test_blockindices.jl b/test/test_blockindices.jl index b57512be..975b655e 100644 --- a/test/test_blockindices.jl +++ b/test/test_blockindices.jl @@ -323,6 +323,40 @@ end end end + @testset "BlockedUnitRange indexing" begin + a = 2:10 + b = blockedrange(2, [1,2,3]) + @test a[b] == blockedrange(3, [1,2,3]) + @test a[b] isa BlockedUnitRange + @test first(a[b]) == 3 + @test blocklengths(a[b]) == [1,2,3] + @test_throws BoundsError a[blockedrange(5, [1,2,3])] + + a = blockedrange(2, [4,5]) + b = blockedrange(2, [1,2,3]) + @test a[b] == blockedrange(3, [1,2,3]) + @test a[b] isa BlockedUnitRange + @test first(a[b]) == 3 + @test blocklengths(a[b]) == [1,2,3] + @test_throws BoundsError a[blockedrange(5, [1,2,3])] + + a = Base.OneTo(9) + b = blockedrange([1,2,3]) + @test a[b] == blockedrange([1,2,3]) + @test a[b] isa BlockedOneTo + @test first(a[b]) == 1 + @test blocklengths(a[b]) == [1,2,3] + @test_throws BoundsError a[blockedrange([1,2,3,4])] + + a = blockedrange([4,5]) + b = blockedrange([1,2,3]) + @test a[b] == blockedrange([1,2,3]) + @test a[b] isa BlockedOneTo + @test first(a[b]) == 1 + @test blocklengths(a[b]) == [1,2,3] + @test_throws BoundsError a[blockedrange([1,2,3,4])] + end + @testset "misc" begin b = blockedrange(1, [1,2,3]) @test axes(b) == Base.unsafe_indices(b) == (b,) From 8a33d14e9104ed845d4f70a621669d931c3a121c Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Tue, 1 Jul 2025 10:40:33 -0400 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: Sheehan Olver --- src/blockaxis.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/blockaxis.jl b/src/blockaxis.jl index d8333202..3a3542f6 100644 --- a/src/blockaxis.jl +++ b/src/blockaxis.jl @@ -178,10 +178,10 @@ end function getindex(r::Base.OneTo{T}, s::BlockedOneTo) where T @inline @boundscheck checkbounds(r, s) - return BlockedOneTo(T.(blocklasts(s))) + return BlockedOneTo(convert(AbstractVector{T}, blocklasts(s))) end function getindex(r::BlockedOneTo{T}, s::BlockedOneTo) where T - return Base.OneTo(r)[s] + return Base.oneto(r)[s] end """