Type instability in GroupNormL2 on julia pre #152

@MaxenceGollier

JET blames these lines of code (other instabilities were also detected in prox!). I think iterating with zip is not type-stable here and should be avoided.

for (idx, λ) in zip(f.idx, f.lambda)
  @views sum_c += λ * norm(x[idx])
end
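One possible way to avoid zip, sketched under assumptions: iterate over a shared index and read both fields with it. GroupNormSketch below is a hypothetical stand-in for the package's GroupNormL2 (only the field names idx and lambda are taken from the snippet above), and I have not checked whether this removes the JET reports on the Julia pre-release.

```julia
using LinearAlgebra

# Hypothetical stand-in for GroupNormL2; not the package's actual type.
struct GroupNormSketch{V <: AbstractVector{<:Real}, I <: AbstractVector}
  lambda::V  # one weight per group
  idx::I     # one index set per group (ranges, Colon, ...)
end

# Evaluate sum_i lambda[i] * ||x[idx[i]]||_2 with a single loop counter instead of zip.
function (f::GroupNormSketch)(x::AbstractVector{R}) where {R <: Real}
  sum_c = zero(R)
  for i in eachindex(f.idx, f.lambda)  # throws if the two vectors have different axes
    sum_c += f.lambda[i] * norm(view(x, f.idx[i]))
  end
  return sum_c
end

f = GroupNormSketch([1.0, 2.0], [1:2, 3:4])
x = [3.0, 4.0, 0.0, 5.0]
f(x)  # ≈ 1.0 * 5.0 + 2.0 * 5.0
```

The loop body is the same as in the zip version; only the way idx and λ are fetched changes, so each element access is inferred from the concrete element type of its own vector.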

Here is an MWE:

using JET  # provides @report_opt
using LinearAlgebra

using ProximalOperators
using ShiftedProximalOperators

h = NormL2(1.0)
n = 1000
xk = rand(n)
y = rand(n)

ψ = shifted(h, xk)

println(@report_opt(ψ(y)))
println(@report_opt(prox!(y, ψ, y, 1.0)))

returns

═════ 2 possible errors found ═════
┌ (::ShiftedGroupNormL2{Float64, Vector{…}, Vector{…}, Vector{…}, Vector{…}, Vector{…}})(y::Vector{Float64}) @ ShiftedProximalOperators C:\Users\mgoll\ShiftedProximalOperators.jl-1\src\ShiftedProximalOperators.jl:53
│┌ (::GroupNormL2{Float64, Vector{Float64}, Vector{Colon}})(x::Vector{Float64}) @ ShiftedProximalOperators C:\Users\mgoll\ShiftedProximalOperators.jl-1\src\groupNormL2.jl:35  
││┌ iterate(z::Base.Iterators.Zip{Tuple{Vector{Colon}, Vector{Float64}}}) @ Base.Iterators ./iterators.jl:415
│││┌ _zip_iterate_all(is::Tuple{Vector{Colon}, Vector{Float64}}, ss::Tuple{Tuple{}, Tuple{}}) @ Base.Iterators ./iterators.jl:429
││││ runtime dispatch detected: Base.Iterators._zip_iterate_interleave(%7::Tuple{Union{Tuple{Colon, Int64}, Tuple{Float64, Int64}}, Tuple{Float64, Int64}}, (), (missing, missing))::Tuple{Tuple{Union{…}, Union{…}}, Tuple{Union{…}, Union{…}}}
│││└────────────────────
│┌ (::GroupNormL2{Float64, Vector{Float64}, Vector{Colon}})(x::Vector{Float64}) @ ShiftedProximalOperators C:\Users\mgoll\ShiftedProximalOperators.jl-1\src\groupNormL2.jl:37  
││ runtime dispatch detected: iterate(%3::Base.Iterators.Zip{Tuple{Vector{Colon}, Vector{Float64}}}, %13::Tuple{Union{Colon, Float64, Int64}, Union{Float64, Int64}})::Union{Nothing, Tuple{Tuple{Union{…}, Union{…}}, Tuple{Union{…}, Union{…}}}}
│└────────────────────

═════ 14 possible errors found ═════
┌ prox!(y::Vector{…}, ψ::ShiftedGroupNormL2{…}, q::Vector{…}, σ::Float64) @ ShiftedProximalOperators C:\Users\mgoll\ShiftedProximalOperators.jl-1\src\shiftedGroupNormL2.jl:66 
│┌ iterate(z::Base.Iterators.Zip{Tuple{Vector{Colon}, Vector{Float64}}}) @ Base.Iterators ./iterators.jl:415
││┌ _zip_iterate_all(is::Tuple{Vector{Colon}, Vector{Float64}}, ss::Tuple{Tuple{}, Tuple{}}) @ Base.Iterators ./iterators.jl:429
│││ runtime dispatch detected: Base.Iterators._zip_iterate_interleave(%7::Tuple{Union{Tuple{Colon, Int64}, Tuple{Float64, Int64}}, Tuple{Float64, Int64}}, (), (missing, missing))::Tuple{Tuple{Union{…}, Union{…}}, Tuple{Union{…}, Union{…}}}
││└────────────────────
┌ prox!(y::Vector{…}, ψ::ShiftedGroupNormL2{…}, q::Vector{…}, σ::Float64) @ ShiftedProximalOperators C:\Users\mgoll\ShiftedProximalOperators.jl-1\src\shiftedGroupNormL2.jl:74 
│┌ materialize!(dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:901
││┌ materialize!(::Base.Broadcast.DefaultArrayStyle{…}, dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:904
│││┌ copyto!(dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:946
││││┌ copyto!(dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:990
│││││┌ preprocess(dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:973
││││││ runtime dispatch detected: Base.Broadcast.Broadcasted(nothing, *, %2::Tuple{Union{Float64, Base.Broadcast.Extruded{…}}, Base.Broadcast.Extruded{SubArray{…}, Tuple{…}, Tuple{…}}}, %3::Tuple{Base.OneTo{Int64}})::Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{…}}, typeof(*), <:Tuple{Union{…}, Base.Broadcast.Extruded{…}}}
│││││└────────────────────
│┌ materialize!(dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:901
││┌ materialize!(::Base.Broadcast.DefaultArrayStyle{…}, dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:904
│││┌ copyto!(dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:946
││││┌ copyto!(dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:990
│││││┌ preprocess(dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:973
││││││ runtime dispatch detected: Base.Broadcast.Broadcasted(nothing, *, %2::Tuple{Union{Float64, Base.Broadcast.Extruded{…}}, Base.Broadcast.Extruded{SubArray{…}, Tuple{…}, Tuple{…}}}, ())::Base.Broadcast.Broadcasted{Nothing, Tuple{}, typeof(*), <:Tuple{Union{…}, Base.Broadcast.Extruded{…}}}
│││││└────────────────────
│┌ materialize!(dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:901
││┌ materialize!(::Base.Broadcast.DefaultArrayStyle{…}, dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:904
│││┌ copyto!(dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:957
││││┌ copyto!(dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:990
│││││┌ preprocess(dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:973
││││││ runtime dispatch detected: Base.Broadcast.Broadcasted(nothing, *, %2::Tuple{Union{Float64, Base.Broadcast.Extruded{…}}, Base.Broadcast.Extruded{SubArray{…}, Tuple{}, Tuple{}}}, %3::Tuple{Base.OneTo{Int64}})::Base.Broadcast.Broadcasted{Nothing, Tuple{Base.OneTo{…}}, typeof(*), <:Tuple{Union{…}, Base.Broadcast.Extruded{…}}}
│││││└────────────────────
│┌ materialize!(dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:901
││┌ materialize!(::Base.Broadcast.DefaultArrayStyle{…}, dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:904
│││┌ copyto!(dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:957
││││┌ copyto!(dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:990
│││││┌ preprocess(dest::SubArray{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:973
││││││ runtime dispatch detected: Base.Broadcast.Broadcasted(nothing, *, %2::Tuple{Union{Float64, Base.Broadcast.Extruded{…}}, Base.Broadcast.Extruded{SubArray{…}, Tuple{}, Tuple{}}}, ())::Base.Broadcast.Broadcasted{Nothing, Tuple{}, typeof(*), <:Tuple{Union{…}, Base.Broadcast.Extruded{…}}}
│││││└────────────────────
┌ prox!(y::Vector{…}, ψ::ShiftedGroupNormL2{…}, q::Vector{…}, σ::Float64) @ ShiftedProximalOperators C:\Users\mgoll\ShiftedProximalOperators.jl-1\src\shiftedGroupNormL2.jl:77 
│┌ materialize!(dest::Vector{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:901
││┌ materialize!(::Base.Broadcast.DefaultArrayStyle{…}, dest::Vector{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:904
│││┌ copyto!(dest::Vector{…}, bc::Base.Broadcast.Broadcasted{…}) @ Base.Broadcast ./broadcast.jl:946
││││┌ copyto!(dest::Vector{Float64}, bc::Base.Broadcast.Broadcasted{Nothing, Tuple{…}, typeof(-), Tuple{…}}) @ Base.Broadcast ./broadcast.jl:990
│││││┌ preprocess(dest::Vector{Float64}, bc::Base.Broadcast.Broadcasted{Nothing, Tuple{…}, typeof(-), Tuple{…}}) @ Base.Broadcast ./broadcast.jl:973
││││││┌ preprocess_args(dest::Vector{Float64}, args::Tuple{Vector{…}, Base.Broadcast.Broadcasted{…}}) @ Base.Broadcast ./broadcast.jl:976
│││││││┌ preprocess(dest::Array, bc::Base.Broadcast.Broadcasted) @ Base.Broadcast ./broadcast.jl:973
││││││││┌ Base.Broadcast.Broadcasted(style::Nothing, f::Any, args::Tuple, axes::Any) @ Base.Broadcast ./broadcast.jl:178
│││││││││┌ convert(::Type{Args} where Args<:Tuple, x::Tuple) @ Base ./essentials.jl:659
││││││││││┌ Val(x::Int64) @ Base ./essentials.jl:1085
│││││││││││ runtime dispatch detected: %1::Type{Val{_A}} where _A()::Val
││││││││││└────────────────────
│││││││┌ preprocess_args(dest::Array, args::Tuple{Any}) @ Base.Broadcast ./broadcast.jl:977
││││││││┌ preprocess(dest::Array, x::Any) @ Base.Broadcast ./broadcast.jl:974
│││││││││┌ broadcast_unalias(dest::Array, src::Any) @ Base.Broadcast ./broadcast.jl:967
││││││││││┌ unalias(dest::Array, A::AbstractArray) @ Base ./abstractarray.jl:1548
│││││││││││┌ mightalias(A::Array, B::AbstractArray) @ Base ./abstractarray.jl:1585
││││││││││││┌ _isdisjoint(as::Tuple{UInt64}, bs::Tuple) @ Base ./abstractarray.jl:1593
│││││││││││││┌ in(x::UInt64, itr::Tuple) @ Base ./operators.jl:1383
││││││││││││││┌ _in_tuple(x::UInt64, itr::Tuple, anymissing::Bool) @ Base ./operators.jl:1399
│││││││││││││││┌ _in_tuple(x::UInt64, itr::Tuple, anymissing::Bool) @ Base ./operators.jl:1393
││││││││││││││││ runtime dispatch detected: (%11::Any == x::UInt64)::Any
│││││││││││││││└────────────────────
│││││││││││││││┌ _in_tuple(x::UInt64, itr::Tuple, anymissing::Bool) @ Base ./operators.jl:1394
││││││││││││││││ runtime dispatch detected: ismissing(%12::Any)::Bool
│││││││││││││││└────────────────────
│││││││││││││││┌ _in_tuple(x::UInt64, itr::Tuple, anymissing::Bool) @ Base ./operators.jl:1399
││││││││││││││││ runtime dispatch detected: Base.tail(itr::Tuple)::Tuple
│││││││││││││││└────────────────────
││││││││││││││┌ _in_tuple(x::UInt64, itr::Tuple, anymissing::Bool) @ Base ./operators.jl:1393
│││││││││││││││ runtime dispatch detected: (%10::Any == x::UInt64)::Any
││││││││││││││└────────────────────
││││││││││││││┌ _in_tuple(x::UInt64, itr::Tuple, anymissing::Bool) @ Base ./operators.jl:1394
│││││││││││││││ runtime dispatch detected: ismissing(%11::Any)::Bool
││││││││││││││└────────────────────
││││││┌ preprocess_args(dest::Vector{Float64}, args::Tuple{Vector{…}, Base.Broadcast.Broadcasted{…}}) @ Base.Broadcast ./broadcast.jl:976
│││││││ failed to optimize due to recursion: Base.Broadcast.preprocess_args(::Vector{Float64}, ::Tuple{Vector{…}, Base.Broadcast.Broadcasted{…}})
││││││└────────────────────
│││││┌ preprocess(dest::Vector{Float64}, bc::Base.Broadcast.Broadcasted{Nothing, Tuple{…}, typeof(-), Tuple{…}}) @ Base.Broadcast ./broadcast.jl:973
││││││ failed to optimize due to recursion: Base.Broadcast.preprocess(::Vector{Float64}, ::Base.Broadcast.Broadcasted{Nothing, Tuple{…}, typeof(-), Tuple{…}})
│││││└────────────────────
┌ prox!(y::Vector{…}, ψ::ShiftedGroupNormL2{…}, q::Vector{…}, σ::Float64) @ ShiftedProximalOperators C:\Users\mgoll\ShiftedProximalOperators.jl-1\src\shiftedGroupNormL2.jl:76 
│ runtime dispatch detected: iterate(%23::Base.Iterators.Zip{Tuple{Vector{Colon}, Vector{Float64}}}, %32::Tuple{Union{Colon, Float64, Int64}, Union{Float64, Int64}})::Union{Nothing, Tuple{Tuple{Union{…}, Union{…}}, Tuple{Union{…}, Union{…}}}}
