Skip to content

Commit da37ebc

Browse files
committed
remove weights from most inputs
1 parent 721bb5a commit da37ebc

3 files changed

Lines changed: 34 additions & 36 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 1 addition & 1 deletion
--- a/src/maxtext/configs/base.yml
+++ b/src/maxtext/configs/base.yml
@@ -923,7 +923,7 @@ xprof_e2e_enable_fw_power_level_event: False
 xprof_e2e_enable_fw_thermal_event: False
 profile_power_events: False # Set to True to enable TPU-specific power/thermal profiling events. Defaults to False to avoid breaking GPU xplane tracing.
 
-log_config: True # Prints the config (after defaults have been set by pyconfig logic)
+log_config: False # Prints the config (after defaults have been set by pyconfig logic)
 debug_sharding: False # Prints model weights sharding info
 
 # Checkpoint Structured logging

src/maxtext/layers/pipeline.py

Lines changed: 7 additions & 9 deletions
--- a/src/maxtext/layers/pipeline.py
+++ b/src/maxtext/layers/pipeline.py
@@ -1252,9 +1252,7 @@ def weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
     bsw_1 = self.from_repeat_weights_to_bsw(nxt_repeat_weights, physical_partition_spec)
     return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw")
 
-  def run_one_iteration(
-      self, loop_state, bsw, weights, positions, segment_ids, deterministic, model_mode, logical_partition_spec
-  ):
+  def run_one_iteration(self, loop_state, bsw, positions, segment_ids, deterministic, model_mode, logical_partition_spec):
     """Executes the forward/backward logic for a single microbatch inside the pipeline.
 
     This acts as the core step function that our `jax.lax.scan` wrappers call. It routes
@@ -1353,12 +1351,11 @@ def __call__(
 
     logical_partition_spec = pipeline_utils.strip_pipeline_repeat_logical_axis(logical_partition_spec)
 
-    def run_iteration_scannable(model, loop_state, bsw, weights):
+    def run_iteration_scannable(model, loop_state, bsw):
       return (
           model.run_one_iteration(
               loop_state,
               bsw,
-              weights,
               positions,
               segment_ids,
               deterministic,
@@ -1386,6 +1383,7 @@ def run_iteration_scannable(model, loop_state, bsw, weights):
         physical_partition_spec=physical_partition_spec,
         positions=positions,
         segment_ids=segment_ids,
+        pipeline_weights=weights,
     )
 
     run_one_repeat_scannable = base_scannable(length=self.config.num_pipeline_microbatches)
@@ -1396,13 +1394,13 @@ def run_iteration_scannable(model, loop_state, bsw, weights):
           pipeline_stage_fn=run_one_repeat_scannable, length=self.config.num_pipeline_repeats
       )
       run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(pipeline_stage_fn=run_bubbles_scannable, length=1)
-      (loop_state, bsw, weights), _ = run_repeats_scanned(self, (loop_state, bsw, weights))
-      (loop_state, bsw, weights), _ = run_bubbles_scanned(self, (loop_state, bsw, weights))
+      (loop_state, bsw), _ = run_repeats_scanned(self, (loop_state, bsw))
+      (loop_state, bsw), _ = run_bubbles_scanned(self, (loop_state, bsw))
     else:
       for _ in range(self.config.num_pipeline_repeats):
-        (loop_state, bsw, weights), _ = run_one_repeat_scannable(self, (loop_state, bsw, weights))
+        (loop_state, bsw), _ = run_one_repeat_scannable(self, (loop_state, bsw))
       for _ in range(bubble_iterations):
-        (loop_state, bsw, weights), _ = run_iteration_scannable(self, loop_state, bsw, weights)
+        (loop_state, bsw), _ = run_iteration_scannable(self, loop_state, bsw)
 
     final_output = self.realign_output_microbatches(loop_state["state_io"])
     final_output = jnp.reshape(

src/maxtext/utils/pipeline_utils.py

Lines changed: 26 additions & 26 deletions
--- a/src/maxtext/utils/pipeline_utils.py
+++ b/src/maxtext/utils/pipeline_utils.py
@@ -194,56 +194,54 @@ def create_gradient_accumulation_scan(
   """
 
   @functools.partial(jax.custom_vjp)
-  def run_single_microbatch_custom(lightweight_state, bsw, weights, pos_arg, seg_arg):
-    return run_single_microbatch_custom_fwd(lightweight_state, bsw, weights, pos_arg, seg_arg)[0]
+  def run_single_microbatch_custom(lightweight_state, bsw, pos_arg, seg_arg):
+    return run_single_microbatch_custom_fwd(lightweight_state, bsw, pos_arg, seg_arg)[0]
 
-  def run_single_microbatch_custom_fwd(lightweight_state, bsw, weights, pos_arg, seg_arg):
-    def _run(l, b, w):
+  def run_single_microbatch_custom_fwd(lightweight_state, bsw, pos_arg, seg_arg):
+    def _run(l, b):
       out = model.run_one_iteration(
-          l, b, w, pos_arg, seg_arg, deterministic, model_mode, logical_partition_spec=logical_partition_spec
+          l, b, pos_arg, seg_arg, deterministic, model_mode, logical_partition_spec=logical_partition_spec
       )
-      return out, b, w
+      return out, b
 
     # Rematerialize the inner step to save activation memory
     _run_remat = jax.remat(_run, prevent_cse=False, policy=model.get_pipeline_remat_policy())
-    out, vjp_fun = jax.vjp(_run_remat, lightweight_state, bsw, weights)
+    out, vjp_fun = jax.vjp(_run_remat, lightweight_state, bsw)
     return out, vjp_fun
 
   def run_single_microbatch_custom_bwd(res, g_out):
     vjp_fun = res
-    d_l, d_b, d_w = vjp_fun(g_out)
-    return d_l, d_b, d_w, None, None
+    d_l, d_b = vjp_fun(g_out)
+    return d_l, d_b, None, None
 
   run_single_microbatch_custom.defvjp(run_single_microbatch_custom_fwd, run_single_microbatch_custom_bwd)
 
   @functools.partial(jax.custom_vjp)
-  def run_pipeline_microbatches_custom(loop_state, bsw, weights, positions, segment_ids):
-    return run_pipeline_microbatches_custom_fwd(loop_state, bsw, weights, positions, segment_ids)[0]
+  def run_pipeline_microbatches_custom(loop_state, bsw, positions, segment_ids):
+    return run_pipeline_microbatches_custom_fwd(loop_state, bsw, positions, segment_ids)[0]
 
-  def run_pipeline_microbatches_custom_fwd(loop_state, bsw, weights, positions, segment_ids):
+  def run_pipeline_microbatches_custom_fwd(loop_state, bsw, positions, segment_ids):
     final_lightweight, scan_vjp_fun = jax.vjp(
-        lambda l, b, w: jax.lax.scan(
-            lambda carry, _: (run_single_microbatch_custom(carry, b, w, positions, segment_ids)[0], None),
+        lambda l, b: jax.lax.scan(
+            lambda carry, _: (run_single_microbatch_custom(carry, b, positions, segment_ids)[0], None),
             l,
             None,
             length=length,
         )[0],
         loop_state,
         bsw,
-        weights,
     )
 
-    return (final_lightweight, bsw, weights), scan_vjp_fun
+    return (final_lightweight, bsw), scan_vjp_fun
 
   def run_pipeline_microbatches_custom_bwd(residuals, g_final_state):
     scan_vjp_fun = residuals
-    g_lightweight, g_bsw, g_weights = g_final_state
-    d_init_lightweight, d_init_bsw, d_init_weights = scan_vjp_fun(g_lightweight)
+    g_lightweight, g_bsw = g_final_state
+    d_init_lightweight, d_init_bsw = scan_vjp_fun(g_lightweight)
 
     d_init_bsw = jax.tree.map(lambda d, g: d + g if hasattr(d, "shape") else d, d_init_bsw, g_bsw)
-    d_init_weights = jax.tree.map(lambda d, g: d + g if hasattr(d, "shape") else d, d_init_weights, g_weights)
 
-    return (d_init_lightweight, d_init_bsw, d_init_weights, None, None)
+    return (d_init_lightweight, d_init_bsw, None, None)
 
   run_pipeline_microbatches_custom.defvjp(run_pipeline_microbatches_custom_fwd, run_pipeline_microbatches_custom_bwd)
   return run_pipeline_microbatches_custom
@@ -259,6 +257,7 @@ def create_rematerialized_pipeline_stage(
     physical_partition_spec,
     positions,
     segment_ids,
+    pipeline_weights,
 ):
   """Builds a memory-checkpointed execution block for a single pipeline stage.
 
@@ -279,16 +278,17 @@ def create_rematerialized_pipeline_stage(
     physical_partition_spec: Rules for physical device mesh mappings (used in prefetching).
     positions: Position IDs for the sequence.
     segment_ids: Segment/Attention routing IDs for the sequence.
+    pipeline_weights: The fully gathered pipeline weights explicitly passed via closure.
 
   Returns:
     A function decorated with `nn.remat` that takes `(model, loop_state)` and returns
     the updated `loop_state`.
   """
 
-  def execute_pipeline_stage(model, loop_state_and_bsw_and_weights):
-    loop_state, bsw, weights = loop_state_and_bsw_and_weights
+  def execute_pipeline_stage(model, loop_state_and_bsw):
+    loop_state, bsw = loop_state_and_bsw
     # Retrieve the specific weights needed for this pipeline chunk
-    bsw = model.weight_prefetching(weights, physical_partition_spec, loop_state["loop_iteration"])
+    bsw = model.weight_prefetching(pipeline_weights, physical_partition_spec, loop_state["loop_iteration"])
 
     if model.config.scan_pipeline_iterations:
       scan_microbatches_fn = create_gradient_accumulation_scan(
@@ -298,11 +298,11 @@ def execute_pipeline_stage(model, loop_state_and_bsw_and_weights):
           model_mode=model_mode,
           logical_partition_spec=logical_partition_spec,
       )
-      loop_state, bsw, weights = scan_microbatches_fn(loop_state, bsw, weights, positions, segment_ids)
+      loop_state, bsw = scan_microbatches_fn(loop_state, bsw, positions, segment_ids)
     else:
       for _ in range(length):
-        (loop_state, bsw, weights), _ = run_iteration_scannable(model, loop_state, bsw, weights)
-    return (loop_state, bsw, weights), None
+        (loop_state, bsw), _ = run_iteration_scannable(model, loop_state, bsw)
+    return (loop_state, bsw), None
 
   return nn.remat(
       execute_pipeline_stage,

0 commit comments

Comments
 (0)