Skip to content

Commit 4a73013

Browse files
committed
BlockArraysExt
1 parent 7d4aff9 commit 4a73013

5 files changed

Lines changed: 39 additions & 3 deletions

File tree

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@ authors = ["ITensor developers <support@itensor.org> and contributors"]
44
version = "0.3.1"
55

66
[weakdeps]
7+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
78
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
89

910
[extensions]
11+
FunctionImplementationsBlockArraysExt = "BlockArrays"
1012
FunctionImplementationsLinearAlgebraExt = "LinearAlgebra"
1113

1214
[compat]
15+
BlockArrays = "1.4"
1316
LinearAlgebra = "1.10"
1417
julia = "1.10"
1518

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module FunctionImplementationsBlockArraysExt
2+
3+
using BlockArrays: AbstractBlockedUnitRange, blockedrange, blocklengths
4+
using FunctionImplementations.Concatenate: Concatenate
5+
6+
function Concatenate.cat_axis(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
7+
first(a1) == first(a2) == 1 || throw(ArgumentError("Concatenated axes must start at 1"))
8+
return blockedrange([blocklengths(a1); blocklengths(a2)])
9+
end
10+
11+
end

src/concatenate.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,19 @@ end
8989

9090
function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T}
9191
# Convert to a broadcasted to leverage its similar implementation.
92-
bc = BC.Broadcasted(style(concat), nothing, ())
93-
return similar(bc, T, ax)
92+
bc = BC.Broadcasted(style(concat), identity, concat.args, ax)
93+
return similar(bc, T)
9494
end
9595

9696
function cat_axis(
9797
a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange...
9898
)
9999
return cat_axis(cat_axis(a1, a2), a_rest...)
100100
end
101-
cat_axis(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2))
101+
function cat_axis(a1::AbstractUnitRange, a2::AbstractUnitRange)
102+
first(a1) == first(a2) == 1 || throw(ArgumentError("Concatenated axes must start at 1"))
103+
return Base.OneTo(length(a1) + length(a2))
104+
end
102105

103106
function cat_ndims(dims, as::AbstractArray...)
104107
return max(maximum(dims), maximum(ndims, as))

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
33
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4+
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
45
FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c"
56
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
67
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -14,6 +15,7 @@ FunctionImplementations = {path = ".."}
1415
[compat]
1516
Adapt = "4"
1617
Aqua = "0.8"
18+
BlockArrays = "1.4"
1719
FunctionImplementations = "0.3"
1820
JLArrays = "0.3"
1921
LinearAlgebra = "1.10"

test/test_blockarraysext.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using BlockArrays: BlockArray, blockedrange, blockisequal
2+
using FunctionImplementations.Concatenate: concatenate
3+
using Test: @test, @testset
4+
5+
@testset "BlockArraysExt" begin
6+
a = BlockArray(randn(4, 4), [2, 2], [2, 2])
7+
b = BlockArray(randn(4, 4), [2, 2], [2, 2])
8+
9+
concat = concatenate(1, a, b)
10+
@test axes(concat) == (Base.OneTo(8), Base.OneTo(4))
11+
@test blockisequal(axes(concat, 1), blockedrange([2, 2, 2, 2]))
12+
@test blockisequal(axes(concat, 2), blockedrange([2, 2]))
13+
@test size(concat) == (8, 4)
14+
@test eltype(concat) Float64
15+
@test copy(concat) == cat(a, b; dims = 1)
16+
@test copy(concat) isa BlockArray{Float64, 2}
17+
end

0 commit comments

Comments
 (0)