Skip to content

Commit 3f08d93

Browse files
Fix recursive_mean to work without Statistics dependency
The move of Statistics.jl to a weak dependency caused recursive_mean to fail because it used Statistics.mean internally. This fix: - Removes the generic fallback `recursive_mean(x...) = mean(x...)` from utils.jl - Adds explicit implementations for scalars and arrays of numbers that compute the mean without requiring Statistics - Rewrites recursive_mean(A::ArrayPartition) to avoid using Statistics.mean 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 15b9ad8 commit 3f08d93

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

src/array_partition.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,17 @@ function recursivecopy!(A::ArrayPartition, B::ArrayPartition)
291291
end
292292
recursivecopy(A::ArrayPartition) = ArrayPartition(copy.(A.x))
293293

294-
recursive_mean(A::ArrayPartition) = mean((recursive_mean(x) for x in A.x))
294+
function recursive_mean(A::ArrayPartition)
295+
n = npartitions(A)
296+
if n == 0
297+
return zero(eltype(A))
298+
end
299+
total = recursive_mean(A.x[1])
300+
for i in 2:n
301+
total += recursive_mean(A.x[i])
302+
end
303+
return total / n
304+
end
295305

296306
# note: consider only first partition for recursive one and eltype
297307
recursive_one(A::ArrayPartition) = recursive_one(first(A.x))

src/utils.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,6 @@ end
291291
recursive_unitless_eltype(a::Type{T}) where {T <: Number} = typeof(one(eltype(a)))
292292
recursive_unitless_eltype(::Type{<:Enum{T}}) where {T} = T
293293

294-
recursive_mean(x...) = mean(x...)
295294
function recursive_mean(vecvec::Vector{T}) where {T <: AbstractArray}
296295
out = zero(vecvec[1])
297296
for i in eachindex(vecvec)
@@ -300,6 +299,15 @@ function recursive_mean(vecvec::Vector{T}) where {T <: AbstractArray}
300299
out / length(vecvec)
301300
end
302301

302+
# Fallback for scalars and general cases without Statistics
303+
function recursive_mean(x::AbstractArray{T}) where {T <: Number}
304+
sum(x) / length(x)
305+
end
306+
307+
function recursive_mean(x::Number)
308+
x
309+
end
310+
303311
# From Iterators.jl. Moved here since Iterators.jl is not precompile safe anymore.
304312

305313
# Concatenate the output of n iterators

0 commit comments

Comments
 (0)