BUG: incorrect adjoint when using checkpoint scheduler #5082

@stephankramer

Consider the following code:

from firedrake import *
from firedrake.adjoint import *
from checkpoint_schedules import SingleMemoryStorageSchedule
continue_annotation()
tape = get_working_tape()

schedule = SingleMemoryStorageSchedule()

tape.enable_checkpointing(schedule)

mesh = UnitSquareMesh(1,1)
V = FunctionSpace(mesh, "CG", 1)

m = Function(V).assign(1.0)

sumf = Function(V)
u = Function(V)
tst = TestFunction(V)
F = tst*u*dx - tst*m*m*dx

problem = NonlinearVariationalProblem(F, u)
solver = NonlinearVariationalSolver(problem)

for step in tape.timestepper(iter(range(4))):
    solver.solve()
    sumf.assign(sumf + u)

J = assemble(sumf*dx)
rf = ReducedFunctional(J, Control(m))

m0 = Function(V).assign(2.0)
print(rf(m0))
print(rf.derivative(apply_riesz=True).dat.data)

The taped model computes $$J=4m^2$$, so the derivative should be $$\frac{dJ}{dm} = 8m$$, and for m=2 I expect a derivative of 16. With the checkpointing schedule above, however, the forward (re)run with m=2 correctly produces $$J=16$$, but the derivative it computes is 8. That is the value of $$8m$$ at m=1, the control value used during taping, so it seems that the adjoint evaluation reuses the old control value rather than the current one. If I comment out the enable_checkpointing line, I get the expected result.
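The expected numbers can be checked without Firedrake. The following is only a scalar sketch of the arithmetic, assuming a spatially constant m on a unit-area domain; the helper names `J` and `dJdm_fd` are hypothetical and not part of Firedrake or pyadjoint.

```python
def J(m, steps=4, area=1.0):
    """Scalar stand-in for the taped model (assumes constant m, unit-area mesh)."""
    sumf = 0.0
    for _ in range(steps):
        u = m * m      # each solve of tst*u*dx - tst*m*m*dx = 0 gives u = m**2
        sumf += u      # mirrors sumf.assign(sumf + u)
    return sumf * area  # mirrors assemble(sumf*dx) on the unit square


def dJdm_fd(m, h=1e-6):
    """Central finite-difference approximation of dJ/dm."""
    return (J(m + h) - J(m - h)) / (2 * h)


print(J(2.0))        # functional value at the new control m=2
print(dJdm_fd(2.0))  # approximately 16: derivative at the current control
print(dJdm_fd(1.0))  # approximately 8: derivative at the taped control m=1
```

The derivative at m=1 reproduces the value 8 reported with checkpointing enabled, which is consistent with the adjoint being evaluated at the taped control rather than at m0.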
