Skip to content

Commit 5cf5c32

Browse files
committed
update custom vjp
1 parent f45ed78 commit 5cf5c32

1 file changed

Lines changed: 39 additions & 14 deletions

File tree

src/maxtext/utils/pipeline_utils.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,9 @@ def run_scanned_custom(loop_state, positions, segment_ids):
143143
return final_state
144144

145145
def run_scanned_custom_fwd(loop_state, positions, segment_ids):
146-
def step_fn(s):
146+
def step_fn(lightweight_state, bsw, weights):
147+
# Reconstruct full state for the model
148+
s = {**lightweight_state, "bsw": bsw, "weights": weights}
147149
out = model.run_one_iteration(
148150
s,
149151
positions,
@@ -152,40 +154,63 @@ def step_fn(s):
152154
model_mode,
153155
logical_partition_spec=logical_partition_spec,
154156
)
155-
return out
157+
# Deconstruct back to lightweight to decouple gradients
158+
new_lightweight = {k: v for k, v in out.items() if k not in ["bsw", "weights"]}
159+
return new_lightweight, out["bsw"], out["weights"]
156160

157161
step_fn_remat = jax.remat(step_fn, prevent_cse=False, policy=model.get_pipeline_remat_policy())
158162

159-
# Forward pass generating VJP functions and intermediate state
163+
# Separate heavy and light state initially
164+
initial_lightweight = {k: v for k, v in loop_state.items() if k not in ["bsw", "weights"]}
165+
initial_bsw = loop_state["bsw"]
166+
initial_weights = loop_state["weights"]
167+
160168
def scan_body_fwd(carry, _):
161-
new_state, vjp_fun = jax.vjp(step_fn_remat, carry)
162-
return new_state, vjp_fun
169+
lightweight_carry, bsw_carry, weights_carry = carry
170+
(new_lightweight, new_bsw, new_weights), vjp_fun = jax.vjp(
171+
step_fn_remat, lightweight_carry, bsw_carry, weights_carry
172+
)
173+
return (new_lightweight, new_bsw, new_weights), vjp_fun
163174

164-
final_state, vjp_funs = jax.lax.scan(
175+
(final_lightweight, final_bsw, final_weights), vjp_funs = jax.lax.scan(
165176
scan_body_fwd,
166-
loop_state,
177+
(initial_lightweight, initial_bsw, initial_weights),
167178
None,
168179
length=length,
169180
)
170181

171-
# We return vjp_funs as residual. model is passed to bwd as arg.
182+
final_state = {**final_lightweight, "bsw": final_bsw, "weights": final_weights}
172183
return final_state, vjp_funs
173184

174185
def run_scanned_custom_bwd(residuals, g_final_state):
175186
vjp_funs = residuals
176187

177-
# Backward scan to accumulate gradients
188+
# Split the gradient of the final state
189+
g_lightweight = {k: v for k, v in g_final_state.items() if k not in ["bsw", "weights"]}
190+
g_bsw = g_final_state["bsw"]
191+
g_weights = g_final_state["weights"]
192+
178193
def scan_body_bwd(carry, vjp_fun):
179-
d_next_state = carry
194+
d_next_lightweight, d_next_bsw, d_next_weights = carry
180195

181-
# Backprop d_next_state using saved vjp_fun
182-
(d_curr_state,) = vjp_fun(d_next_state)
196+
# Apply saved vjp_fun directly
197+
d_curr_lightweight, d_curr_bsw, d_curr_weights = vjp_fun((d_next_lightweight, d_next_bsw, d_next_weights))
183198

184-
return d_curr_state, None
199+
# Accumulate gradients for invariant parts
200+
d_bsw_accum = jax.tree.map(lambda x, y: x + y, d_next_bsw, d_curr_bsw)
201+
d_weights_accum = jax.tree.map(lambda x, y: x + y, d_next_weights, d_curr_weights)
202+
203+
return (d_curr_lightweight, d_bsw_accum, d_weights_accum), None
185204

186205
# Run backward scan
187-
d_init_state, _ = jax.lax.scan(scan_body_bwd, g_final_state, vjp_funs, reverse=True)
206+
(d_init_lightweight, d_init_bsw, d_init_weights), _ = jax.lax.scan(
207+
scan_body_bwd,
208+
(g_lightweight, g_bsw, g_weights),
209+
vjp_funs,
210+
reverse=True,
211+
)
188212

213+
d_init_state = {**d_init_lightweight, "bsw": d_init_bsw, "weights": d_init_weights}
189214
return (d_init_state, None, None)
190215

191216
run_scanned_custom.defvjp(run_scanned_custom_fwd, run_scanned_custom_bwd)

0 commit comments

Comments
 (0)