Hello,
For educational purposes, I have been implementing an RNN by hand, and I wanted to be fancy and use accumulate instead of recursion or a for loop. But I run into an error when one of the operands in accumulate is a tuple.
I have carved out an MWE, which looks like this:
using Zygote

x = [randn(Float32, 2) for i in 1:3]
h = randn(Float32, 2)

# Plain state: the accumulator is just the hidden vector h.
function f(α, h, x)
    o = accumulate(x, init = h) do h, x
        α * h + x
    end
end

# Tuple state: the accumulator carries (hidden vector, last input).
function g(α, h, x)
    o = accumulate(x, init = (h, x[1])) do (h, _), x
        (α * h + x, x)
    end
    first.(o)
end

gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]
gradient(α -> sum(sum(f(α, h, x))), 1f0)[1]
While computing the gradient of f succeeds, computing the gradient of g crashes with:
julia> gradient(α -> sum(sum(g(α, h, x))), 1f0)[1]
ERROR: MethodError: no method matching construct(::Type{Any}, ::Tuple{FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}, ChainRulesCore.NoTangent})
Closest candidates are:
construct(::Type{T}, ::T) where T<:Tuple
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_types/structural_tangent.jl:251
construct(::Type{T}, ::NamedTuple{L}) where {T, L}
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_types/structural_tangent.jl:235
Stacktrace:
[1] +(a::ChainRulesCore.Tangent{Tuple{…}, Tuple{…}}, d::ChainRulesCore.Tangent{Any, Tuple{…}})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/6DiyF/src/tangent_arithmetic.jl:142
[2] (::ChainRules.var"#1699#1702")(::Tuple{…}, ::Tuple{…})
@ ChainRules ~/.julia/packages/ChainRules/FLsQJ/src/rulesets/Base/mapreduce.jl:541
[3] iterate(itr::Base.Iterators.Accumulate)
@ Base.Iterators ./iterators.jl:589 [inlined]
[4] collect_to!
@ ./array.jl:892 [inlined]
[5] collect_to_with_first!
@ ./array.jl:870 [inlined]
[6] _collect(c::Any, itr::Any, ::Base.EltypeUnknown, isz::Union{Base.HasLength, Base.HasShape})
@ Base ./array.jl:864 [inlined]
[7] collect(itr::Base.Generator)
@ Base ./array.jl:759 [inlined]
[8] #accumulate#893
@ ./accumulate.jl:281 [inlined]
[9] accumulate
@ ./accumulate.jl:278 [inlined]
[10] (::ChainRules.var"#decumulate#1701"{…})(dy::Vector{…})
@ ChainRules ~/.julia/packages/ChainRules/FLsQJ/src/rulesets/Base/mapreduce.jl:540
[11] ZBack
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:211 [inlined]
[12] (::Zygote.var"#kw_zpullback#53"{ChainRules.var"#decumulate#1701"{…}})(dy::Vector{Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/chainrules.jl:237
[13] g
@ ./REPL[43]:2 [inlined]
[14] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::FillArrays.Fill{FillArrays.Fill{…}, 1, Tuple{…}})
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[15] #53
@ ./REPL[44]:1 [inlined]
[16] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface2.jl:0
[17] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:91
[18] gradient(f::Function, args::Float32)
@ Zygote ~/.julia/packages/Zygote/jxHJc/src/compiler/interface.jl:148
[19] top-level scope
@ REPL[44]:1
Some type information was truncated. Use `show(err)` to see complete types.
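For reference, the value I would expect from both calls can be computed without Zygote. The sketch below is a hand-written forward-mode derivative of the recurrence h_t = α * h_{t-1} + x_t used in f and g. The name loss_and_derivative is mine and only for illustration, and it assumes the loss is the sum over all hidden states, as in the gradient calls above.

```julia
# Hand-written forward-mode derivative of h_t = α * h_{t-1} + x_t.
# Returns the sum over all hidden states and its derivative with respect to α.
# Illustrative helper, not part of Zygote or ChainRules.
function loss_and_derivative(α, h, x)
    dh = zero(h)                 # ∂h_0/∂α = 0: the initial state does not depend on α
    total = zero(eltype(h))
    dtotal = zero(eltype(h))
    for xt in x
        dh = h .+ α .* dh        # product rule, using h from the previous step
        h = α .* h .+ xt         # same update as in f and g
        total += sum(h)
        dtotal += sum(dh)
    end
    return total, dtotal
end

x = [randn(Float32, 2) for i in 1:3]
h = randn(Float32, 2)
total, dtotal = loss_and_derivative(1f0, h, x)
```

If the accumulate rule handled the tuple state, I would expect the gradient of g to agree with dtotal up to floating-point error, just as the gradient of f should.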
Julia version and environment:
julia> versioninfo()
Julia Version 1.10.0-rc2
Commit dbb9c46795b (2023-12-03 15:25 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: macOS (x86_64-apple-darwin22.4.0)
CPU: 8 × Intel(R) Core(TM) i5-8279U CPU @ 2.40GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
Threads: 1 on 8 virtual cores
(tmp) pkg> st
Status `/private/tmp/Project.toml`
[082447d4] ChainRules v1.63.0
[d360d2e6] ChainRulesCore v1.21.1
[26cc04aa] FiniteDifferences v0.12.31
[587475ba] Flux v0.14.11
[3bd65402] Optimisers v0.3.2
[eeda0dda] SafeTensors v1.0.0
[2913bbd2] StatsBase v0.34.2
[e88e6eb3] Zygote v0.6.69
Thanks for any help.