ReverseDiff tracks through length #418

@marcoct

Note that this issue occurs on the branch for #417, not on master. On master, this code fails for a different reason, which #417 fixes.

The following code

using Gen

@gen (static) function foo()
    @param b::Vector{Float32}
    n = length(b)
    x = zeros(n)
    a ~ normal(sum(x), 1.0)
    return nothing
end

@load_generated_functions()

init_parameter!((foo, :b), [0.0, 0.0])
trace = simulate(foo, ())
accumulate_param_gradients!(trace)

produces the error:

ERROR: LoadError: MethodError: no method matching zeros(::ReverseDiff.TrackedReal{Int64, Int64, Nothing})
Closest candidates are:
  zeros(::Union{Integer, AbstractUnitRange}...) at array.jl:498
  zeros(::Tuple{Vararg{Union{Integer, AbstractUnitRange}, N} where N}) at array.jl:500
  zeros(::Type{StaticArrays.MVector{N, T} where T}) where N at /home/marcoct/.julia/packages/StaticArrays/xV8rq/src/MVector.jl:25
  ...
Stacktrace:
 [1] (::var"#2#7")(n::ReverseDiff.TrackedReal{Int64, Int64, Nothing})
   @ Main ./none:0
 [2] macro expansion
   @ ~/.julia/packages/Gen/3mYgc/src/static_ir/backprop.jl:0 [inlined]
 [3] accumulate_param_gradients!(trace::var"##StaticIRTrace_foo#270", retval_grad::Nothing, scale_factor::Float64)
   @ Main ~/.julia/packages/Gen/3mYgc/src/static_ir/backprop.jl:549
 [4] accumulate_param_gradients!(trace::var"##StaticIRTrace_foo#270")
   @ Gen ~/.julia/packages/Gen/3mYgc/src/gen_fn_interface.jl:403
 [5] top-level scope
   @ ~/dev/GenExamples.jl/test/test.jl:15
in expression starting at /home/marcoct/dev/GenExamples.jl/test/test.jl:15

A careful redesign of how ReverseDiff is used for AD is probably needed. ReverseDiff is currently used as a stop-gap because it provides differentiation of arithmetic and linear algebra operations. Support for AD of new operations should be added by writing generative functions (e.g. using https://www.gen.dev/dev/ref/extending/#Gen.CustomGradientGF) instead of by extending ReverseDiff.
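For reference, the failure can probably be reproduced without Gen. This is only a minimal sketch: it assumes that calling `ReverseDiff.track` on an integer gives the same tracked type that appears in the stack trace, and it says nothing about how Gen's backprop code ends up tracking `n`. Unwrapping with `ReverseDiff.value` is shown only as a possible stop-gap, not as the proposed fix.

```julia
using ReverseDiff

# Assumption: tracking a plain Int mimics the tracked `n = length(b)` from the report.
n = ReverseDiff.track(2)

# `zeros` only accepts Integer (or range) dimensions, so a tracked real is rejected.
try
    zeros(n)
catch err
    println("zeros(n) failed with: ", typeof(err))
end

# Unwrapping to the underlying Int first avoids the dispatch failure.
x = zeros(ReverseDiff.value(n))
println(x)
```

Since the length of a vector is not something to differentiate with respect to, dropping the tracking here should not lose any gradient information, but whether this unwrapping belongs in user code or in Gen itself is the design question raised above.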
