Skip to content

Commit c41e5d0

Browse files
Merge pull request #482 from ChrisRackauckas-Claude/statistics-weakdep
Move Statistics.jl to weak dependency extension
2 parents 9627135 + 4185d2a commit c41e5d0

File tree

6 files changed

+40
-13
lines changed

6 files changed

+40
-13
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1212
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
1313
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
14-
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1514
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
1615

1716
[weakdeps]
@@ -22,6 +21,7 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
2221
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
2322
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2423
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
24+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2525
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
2626
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2727
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -35,6 +35,7 @@ RecursiveArrayToolsMeasurementsExt = "Measurements"
3535
RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
3636
RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"]
3737
RecursiveArrayToolsSparseArraysExt = ["SparseArrays"]
38+
RecursiveArrayToolsStatisticsExt = "Statistics"
3839
RecursiveArrayToolsStructArraysExt = "StructArrays"
3940
RecursiveArrayToolsTablesExt = ["Tables"]
4041
RecursiveArrayToolsTrackerExt = "Tracker"
@@ -88,11 +89,12 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
8889
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
8990
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
9091
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
92+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
9193
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
9294
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9395
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
9496
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
9597
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
9698

9799
[targets]
98-
test = ["Aqua", "FastBroadcast", "ForwardDiff", "JET", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Random", "SafeTestsets", "SciMLBase", "SparseArrays", "StaticArrays", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]
100+
test = ["Aqua", "FastBroadcast", "ForwardDiff", "JET", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Random", "SafeTestsets", "SciMLBase", "SparseArrays", "StaticArrays", "Statistics", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module RecursiveArrayToolsStatisticsExt
2+
3+
using RecursiveArrayTools
4+
using Statistics
5+
6+
@inline Statistics.mean(VA::AbstractVectorOfArray; kwargs...) = mean(Array(VA); kwargs...)
7+
@inline function Statistics.median(VA::AbstractVectorOfArray; kwargs...)
8+
median(Array(VA); kwargs...)
9+
end
10+
@inline Statistics.std(VA::AbstractVectorOfArray; kwargs...) = std(Array(VA); kwargs...)
11+
@inline Statistics.var(VA::AbstractVectorOfArray; kwargs...) = var(Array(VA); kwargs...)
12+
@inline Statistics.cov(VA::AbstractVectorOfArray; kwargs...) = cov(Array(VA); kwargs...)
13+
@inline Statistics.cor(VA::AbstractVectorOfArray; kwargs...) = cor(Array(VA); kwargs...)
14+
15+
end

src/RecursiveArrayTools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ $(DocStringExtensions.README)
55
module RecursiveArrayTools
66

77
using DocStringExtensions
8-
using RecipesBase, StaticArraysCore, Statistics,
8+
using RecipesBase, StaticArraysCore,
99
ArrayInterface, LinearAlgebra
1010
using SymbolicIndexingInterface
1111

src/array_partition.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,17 @@ function recursivecopy!(A::ArrayPartition{T, S},
342342
return A
343343
end
344344

345-
recursive_mean(A::ArrayPartition) = mean((recursive_mean(x) for x in A.x))
345+
function recursive_mean(A::ArrayPartition)
346+
n = npartitions(A)
347+
if n == 0
348+
return zero(eltype(A))
349+
end
350+
total = recursive_mean(A.x[1])
351+
for i in 2:n
352+
total += recursive_mean(A.x[i])
353+
end
354+
return total / n
355+
end
346356

347357
# note: consider only first partition for recursive one and eltype
348358
recursive_one(A::ArrayPartition) = recursive_one(first(A.x))

src/utils.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,6 @@ end
293293
recursive_unitless_eltype(a::Type{T}) where {T <: Number} = typeof(one(eltype(a)))
294294
recursive_unitless_eltype(::Type{<:Enum{T}}) where {T} = T
295295

296-
recursive_mean(x...) = mean(x...)
297296
function recursive_mean(vecvec::Vector{T}) where {T <: AbstractArray}
298297
out = zero(vecvec[1])
299298
for i in eachindex(vecvec)
@@ -302,6 +301,15 @@ function recursive_mean(vecvec::Vector{T}) where {T <: AbstractArray}
302301
out / length(vecvec)
303302
end
304303

304+
# Fallback for scalars and general cases without Statistics
305+
function recursive_mean(x::AbstractArray{T}) where {T <: Number}
306+
sum(x) / length(x)
307+
end
308+
309+
function recursive_mean(x::Number)
310+
x
311+
end
312+
305313
# From Iterators.jl. Moved here since Iterators.jl is not precompile safe anymore.
306314

307315
# Concatenate the output of n iterators

src/vector_of_array.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,14 +1153,6 @@ end
11531153
mapreduce(f, Base.mul_prod, VA; kwargs...)
11541154
end
11551155

1156-
@inline Statistics.mean(VA::AbstractVectorOfArray; kwargs...) = mean(Array(VA); kwargs...)
1157-
@inline function Statistics.median(VA::AbstractVectorOfArray; kwargs...)
1158-
median(Array(VA); kwargs...)
1159-
end
1160-
@inline Statistics.std(VA::AbstractVectorOfArray; kwargs...) = std(Array(VA); kwargs...)
1161-
@inline Statistics.var(VA::AbstractVectorOfArray; kwargs...) = var(Array(VA); kwargs...)
1162-
@inline Statistics.cov(VA::AbstractVectorOfArray; kwargs...) = cov(Array(VA); kwargs...)
1163-
@inline Statistics.cor(VA::AbstractVectorOfArray; kwargs...) = cor(Array(VA); kwargs...)
11641156
@inline Base.adjoint(VA::AbstractVectorOfArray) = Adjoint(VA)
11651157

11661158
# linear algebra

0 commit comments

Comments
 (0)