Skip to content

Commit f45ed78

Browse files
committed
move remat to fwd
1 parent 33fefd9 commit f45ed78

1 file changed

Lines changed: 22 additions & 44 deletions

File tree

src/maxtext/utils/pipeline_utils.py

Lines changed: 22 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -143,70 +143,48 @@ def run_scanned_custom(loop_state, positions, segment_ids):
143 143
return final_state
144 144

145 145
def run_scanned_custom_fwd(loop_state, positions, segment_ids):
146-
final_state, _ = run_scanned(model, loop_state)
147-
# We return loop_state as residual. model is passed to bwd as arg.
148-
return final_state, (
149-
loop_state,
150-
positions,
151-
segment_ids,
152-
)
153-
154-
def run_scanned_custom_bwd(residuals, g_final_state):
155-
init_loop_state, positions, segment_ids = residuals
156-
157-
# Re-run forward pass to get saved states (checkpointing)
158-
def scan_body_fwd(carry, _):
159-
new_state = model.run_one_iteration(
160-
carry,
146+
def step_fn(s):
147+
out = model.run_one_iteration(
148+
s,
161 149
positions,
162 150
segment_ids,
163 151
deterministic,
164 152
model_mode,
165 153
logical_partition_spec=logical_partition_spec,
166 154
)
167-
# Return lightweight state for saving (exclude bsw/weights)
168-
saved = {k: v for k, v in carry.items() if k not in ["bsw", "weights"]}
169-
return new_state, saved
155+
return out
170 156

171-
_, saved_states = jax.lax.scan(
157+
step_fn_remat = jax.remat(step_fn, prevent_cse=False, policy=model.get_pipeline_remat_policy())
158+
159+
# Forward pass generating VJP functions and intermediate state
160+
def scan_body_fwd(carry, _):
161+
new_state, vjp_fun = jax.vjp(step_fn_remat, carry)
162+
return new_state, vjp_fun
163+
164+
final_state, vjp_funs = jax.lax.scan(
172 165
scan_body_fwd,
173-
init_loop_state,
166+
loop_state,
174 167
None,
175 168
length=length,
176 169
)
177 170

171+
# We return vjp_funs as residual. model is passed to bwd as arg.
172+
return final_state, vjp_funs
173+
174+
def run_scanned_custom_bwd(residuals, g_final_state):
175+
vjp_funs = residuals
176+
178 177
# Backward scan to accumulate gradients
179-
def scan_body_bwd(carry, saved_slice):
178+
def scan_body_bwd(carry, vjp_fun):
180 179
d_next_state = carry
181 180

182-
# Reconstruct current loop_state (input to step)
183-
curr_loop_state = {
184-
**saved_slice,
185-
"bsw": init_loop_state["bsw"],
186-
"weights": init_loop_state["weights"],
187-
}
188-
189-
# Define function to differentiate w.r.t loop_state
190-
def step_fn(s):
191-
out = model.run_one_iteration(
192-
s,
193-
positions,
194-
segment_ids,
195-
deterministic,
196-
model_mode,
197-
logical_partition_spec=logical_partition_spec,
198-
)
199-
return out
200-
201-
_, vjp_fun = jax.vjp(step_fn, curr_loop_state)
202-
203-
# Backprop d_next_state
181+
# Backprop d_next_state using saved vjp_fun
204 182
(d_curr_state,) = vjp_fun(d_next_state)
205 183

206 184
return d_curr_state, None
207 185

208 186
# Run backward scan
209-
d_init_state, _ = jax.lax.scan(scan_body_bwd, g_final_state, saved_states, reverse=True)
187+
d_init_state, _ = jax.lax.scan(scan_body_bwd, g_final_state, vjp_funs, reverse=True)
210 188

211 189
return (d_init_state, None, None)
212 190

0 commit comments

Comments
 (0)