indiscriminately clearing checkpoints with SingleMemoryStorageSchedule corrupts the adjoints #211

@sghelichkhani

Description

This is the second part of an issue we encountered in g-adopt/g-adopt#237. Consider this reproducer (minimal_clearing_cache.py). When running without a scheduler:

python minimal_clearing_cache.py none

we get:

using scheduler: NoneType
        J1: 12.520709541433629
        dJdm 1: 3205.301642607008

But when running with a scheduler:

python minimal_clearing_cache.py memory

we get:

using scheduler: SingleMemoryStorageSchedule
        J1: 12.520709541433629
        dJdm 1: 73765.11424588156

Cause

When SingleMemoryStorageSchedule is used, every checkpoint that does not appear in the previous step's adjoint_dependencies is explicitly deleted:

if isinstance(self._schedule, SingleMemoryStorageSchedule):
    if step > 1 and var not in self.tape.timesteps[step - 1].adjoint_dependencies:
        var._checkpoint = None

If a variable is part of a long-range dependency (e.g. it is only used every third step), it can be missing from the immediately preceding step's dependency set even though it is still required later. Its checkpoint is therefore cleared, the reverse pass reconstructs an incorrect value, and the gradient is wrong.
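To make the failure mode concrete, here is a toy model of the clearing rule in plain Python. It does not use pyadjoint: the per-step dependency sets, the variable names, and the helpers cleared_by_rule, needed_later and wrongly_cleared are all hypothetical, and a checkpoint is simply assumed to be "stored" once any earlier step has depended on it.

```python
# Toy model of the SingleMemoryStorageSchedule special case (not pyadjoint code).
# adjoint_dependencies[s] = variables whose checkpoints the adjoint of step s needs.
adjoint_dependencies = [
    {"m", "u0"},  # step 0
    {"u1"},       # step 1
    {"u2"},       # step 2
    {"m", "u3"},  # step 3: "m" is a long-range dependency (used every third step)
]


def cleared_by_rule(step, var, deps):
    """Mirror of the quoted rule: clear if step > 1 and var is absent
    from the previous step's dependency set."""
    return step > 1 and var not in deps[step - 1]


def needed_later(step, var, deps):
    """True if any step from `step` onwards still depends on var."""
    return any(var in d for d in deps[step:])


def wrongly_cleared(deps):
    """Map each step to the checkpoints the rule clears although they are still needed."""
    result = {}
    for step in range(len(deps)):
        stored = set().union(*deps[:step])  # checkpoints assumed stored before this step
        bad = {v for v in stored
               if cleared_by_rule(step, v, deps) and needed_later(step, v, deps)}
        if bad:
            result[step] = bad
    return result


print(wrongly_cleared(adjoint_dependencies))
```

Running it prints {2: {'m'}, 3: {'m'}}: "m" is cleared at step 2 because step 1 does not depend on it, even though step 3 still does. The short-lived u0 is also cleared, which is harmless since nothing later needs it.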

Given the accompanying comment "Handle the case for SingleMemoryStorageSchedule", I would have thought there was a reason for this, but no explanation is given.

Ideas (!?)

  • Remove the special-case clearing and rely on the revised dependency machinery to decide what can be freed, or

  • Only clear if the variable is guaranteed not to reappear later (e.g. by checking the adjoint_dependencies of every timestep in timesteps[step:])?
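A minimal sketch of the second idea, again as a plain-Python toy model rather than pyadjoint code. It assumes the dependency sets of all timesteps are known when the clearing decision is made; last_use and safe_to_clear are hypothetical helper names.

```python
# Toy model of a safer clearing rule (not pyadjoint code).
adjoint_dependencies = [
    {"m", "u0"},  # step 0
    {"u1"},       # step 1
    {"u2"},       # step 2
    {"m", "u3"},  # step 3: "m" is needed again
]


def last_use(deps):
    """Map each variable to the last step whose adjoint depends on it."""
    last = {}
    for step, d in enumerate(deps):
        for var in d:
            last[var] = step  # later steps overwrite earlier ones
    return last


def safe_to_clear(step, var, last):
    """Free var's checkpoint only if no step >= `step` still needs it.
    Equivalent to: not any(var in d for d in deps[step:])."""
    return last.get(var, -1) < step


last = last_use(adjoint_dependencies)
print(last["m"])
print(safe_to_clear(2, "m", last), safe_to_clear(2, "u0", last))
```

This prints 3, then False True: "m" is kept until its last use at step 3, while u0 can still be freed early. Precomputing the last use once avoids rescanning timesteps[step:] for every variable at every step, which would grow quadratically with the number of steps.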
