Skip to content

Commit 4185d2a

Browse files
Fix recursive_mean to work without Statistics dependency
The move of Statistics.jl to a weak dependency caused recursive_mean to fail because it used Statistics.mean internally. This fix: - Removes the generic fallback `recursive_mean(x...) = mean(x...)` from utils.jl - Adds explicit implementations for scalars and arrays of numbers that compute the mean without requiring Statistics - Rewrites recursive_mean(A::ArrayPartition) to avoid using Statistics.mean 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 114bb9f commit 4185d2a

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

src/array_partition.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,17 @@ function recursivecopy!(A::ArrayPartition{T, S},
342342
return A
343343
end
344344

345-
recursive_mean(A::ArrayPartition) = mean((recursive_mean(x) for x in A.x))
345+
function recursive_mean(A::ArrayPartition)
346+
n = npartitions(A)
347+
if n == 0
348+
return zero(eltype(A))
349+
end
350+
total = recursive_mean(A.x[1])
351+
for i in 2:n
352+
total += recursive_mean(A.x[i])
353+
end
354+
return total / n
355+
end
346356

347357
# note: consider only first partition for recursive one and eltype
348358
recursive_one(A::ArrayPartition) = recursive_one(first(A.x))

src/utils.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,6 @@ end
293293
recursive_unitless_eltype(a::Type{T}) where {T <: Number} = typeof(one(eltype(a)))
294294
recursive_unitless_eltype(::Type{<:Enum{T}}) where {T} = T
295295

296-
recursive_mean(x...) = mean(x...)
297296
function recursive_mean(vecvec::Vector{T}) where {T <: AbstractArray}
298297
out = zero(vecvec[1])
299298
for i in eachindex(vecvec)
@@ -302,6 +301,15 @@ function recursive_mean(vecvec::Vector{T}) where {T <: AbstractArray}
302301
out / length(vecvec)
303302
end
304303

304+
# Fallback for scalars and general cases without Statistics
305+
function recursive_mean(x::AbstractArray{T}) where {T <: Number}
306+
sum(x) / length(x)
307+
end
308+
309+
function recursive_mean(x::Number)
310+
x
311+
end
312+
305313
# From Iterators.jl. Moved here since Iterators.jl is not precompile safe anymore.
306314

307315
# Concatenate the output of n iterators

0 commit comments

Comments
 (0)