Stack overflow issue #809

@jakubMitura14

Hello
I have a memory-constrained problem with a Lux.jl model that uses Zygote for most of the backpropagation.

I tried to approach this from the ChainRules perspective: I need to checkpoint each Lux.jl layer in the neural network, so I defined the following rule:

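using ChainRulesCore, Lux, Zygote

function ChainRulesCore.rrule(::typeof(Lux.apply), l::Lux.AbstractExplicitLayer, x, ps, st)
    # Forward pass: keep only the output, not the intermediate activations.
    y = Lux.apply(l, x, ps, st)

    function pullback_checkpointed(Δy)
        # Recompute the forward pass during backpropagation to obtain the pullback.
        _, pb = Zygote.pullback(Lux.apply, l, x, ps, st)
        # One tangent for Lux.apply itself, then one per argument (l, x, ps, st).
        return (NoTangent(), pb(Δy)...)
    end

    return y, pullback_checkpointed
end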

The rule is invoked during backpropagation. However, for some reason it recursively tries to backpropagate through the first line,

 y = Lux.apply(l, x, ps, st)

so I get a stack overflow error. How can I correct it?

I have also posted this issue at https://discourse.julialang.org/t/avoid-storing-intermediate-results-from-the-forward-pass-by-default/119694/4?u=jakub_mitura
