Skip to content

Commit a588abe

Browse files
committed
Fix @static_unpack and add JET type stability tests
- Fix _maybe_SArray functions in static_arrays.jl to properly construct SVector and SArray types by including element type parameter and using Tuple(x) conversion for type stability - Add JET.jl as test dependency - Add JET type stability tests for core ComponentArray operations - The @static_unpack macro now correctly returns SVector and SMatrix types 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 8ab9830 commit a588abe

4 files changed

Lines changed: 44 additions & 4 deletions

File tree

src/compat/static_arrays.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@ function ComponentArray{A}(::UndefInitializer, ax::Axes) where {
33
return ComponentArray(similar(A), ax...)
44
end
55

6-
_maybe_SArray(x::SubArray, ::Val{N}, ::FlatAxis) where {N} = SVector{N}(x)
7-
function _maybe_SArray(x::Base.ReshapedArray, ::Val, ::ShapedAxis{Sz}) where {Sz}
8-
SArray{Tuple{Sz...}}(x)
6+
function _maybe_SArray(x::SubArray{T}, ::Val{N}, ::FlatAxis) where {T, N}
7+
SVector{N, T}(Tuple(x))
8+
end
9+
function _maybe_SArray(x::Base.ReshapedArray{T, N}, ::Val, ::ShapedAxis{Sz}) where {T, N, Sz}
10+
SArray{Tuple{Sz...}, T, N, prod(Sz)}(Tuple(x))
11+
end
12+
function _maybe_SArray(x::AbstractArray{T}, ::Val, ::Shaped1DAxis{Sz}) where {T, Sz}
13+
SVector{Sz[1], T}(Tuple(x))
914
end
10-
_maybe_SArray(x, ::Val, ::Shaped1DAxis{Sz}) where {Sz} = SArray{Tuple{Sz...}}(x)
1115
_maybe_SArray(x, vals...) = x
1216

1317
@generated function static_getproperty(ca::ComponentVector, ::Val{s}) where {s}

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
77
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
88
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
99
InvertedIndices = "41ab1584-1d38-5bbf-9106-f11c6c58b48f"
10+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1011
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
1112
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
1213
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

test/jet_tests.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using JET
2+
using ComponentArrays
3+
using Test
4+
5+
@testset "JET type stability" begin
6+
# Test key ComponentArray operations for type stability
7+
ca = ComponentArray(a = 1.0, b = [2.0, 3.0], c = (d = 4.0, e = [5.0, 6.0]))
8+
9+
@testset "Core operations" begin
10+
# Test getdata type stability
11+
report = JET.report_opt(ComponentArrays.getdata, (typeof(ca),))
12+
@test length(JET.get_reports(report)) == 0
13+
14+
# Test getaxes type stability
15+
report = JET.report_opt(ComponentArrays.getaxes, (typeof(ca),))
16+
@test length(JET.get_reports(report)) == 0
17+
18+
# Test numeric indexing type stability
19+
report = JET.report_opt(getindex, (typeof(ca), Int))
20+
@test length(JET.get_reports(report)) == 0
21+
end
22+
23+
@testset "Construction" begin
24+
# Test ComponentArray construction from NamedTuple
25+
nt = (a = 1.0, b = [2.0, 3.0])
26+
report = JET.report_opt(ComponentArray, (typeof(nt),))
27+
# Note: some reports may come from Base, we only check ComponentArrays
28+
ca_reports = filter(r -> occursin("ComponentArrays", string(r)), JET.get_reports(report))
29+
@test length(ca_reports) == 0
30+
end
31+
end

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -982,3 +982,7 @@ end
982982
@testset "Reactant" begin
983983
include("reactant_tests.jl")
984984
end
985+
986+
@testset "JET" begin
987+
include("jet_tests.jl")
988+
end

0 commit comments

Comments
 (0)