Skip to content

Commit 2b3da04

Browse files
committed
remove additional AG
1 parent da37ebc commit 2b3da04

2 files changed

Lines changed: 41 additions & 31 deletions

File tree

src/maxtext/layers/pipeline.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,7 +1239,7 @@ def _apply_sharding_hint(weight, pspec):
12391239
return _from_repeat_weights_to_bsw_shardmap(repeat_weights, physical_partition_spec, axes_to_gather=axes_to_gather)
12401240
return _from_repeat_weights_to_bsw_hint(repeat_weights)
12411241

1242-
def weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
1242+
def both_weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
12431243
"""Triggers asynchronous FSDP-like all-gathers for the current and next pipeline steps.
12441244
12451245
By gathering weights for `loop_iteration + 1` right now, the network communication
@@ -1250,7 +1250,16 @@ def weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
12501250
nxt_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1)
12511251
bsw_0 = self.from_repeat_weights_to_bsw(cur_repeat_weights, physical_partition_spec)
12521252
bsw_1 = self.from_repeat_weights_to_bsw(nxt_repeat_weights, physical_partition_spec)
1253-
return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw")
1253+
return bsw_0, bsw_1
1254+
1255+
def one_weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
1256+
"""Triggers asynchronous FSDP-like all-gathers for the next pipeline steps.
1257+
1258+
By gathering weights for `loop_iteration + 1` right now, the network communication
1259+
can overlap with the compute happening in `loop_iteration`.
1260+
"""
1261+
repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1)
1262+
return self.from_repeat_weights_to_bsw(repeat_weights, physical_partition_spec)
12541263

12551264
def run_one_iteration(self, loop_state, bsw, positions, segment_ids, deterministic, model_mode, logical_partition_spec):
12561265
"""Executes the forward/backward logic for a single microbatch inside the pipeline.
@@ -1389,18 +1398,12 @@ def run_iteration_scannable(model, loop_state, bsw):
13891398
run_one_repeat_scannable = base_scannable(length=self.config.num_pipeline_microbatches)
13901399
run_bubbles_scannable = base_scannable(length=bubble_iterations)
13911400

1392-
if self.config.scan_pipeline_repeats:
1393-
run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
1394-
pipeline_stage_fn=run_one_repeat_scannable, length=self.config.num_pipeline_repeats
1395-
)
1396-
run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(pipeline_stage_fn=run_bubbles_scannable, length=1)
1397-
(loop_state, bsw), _ = run_repeats_scanned(self, (loop_state, bsw))
1398-
(loop_state, bsw), _ = run_bubbles_scanned(self, (loop_state, bsw))
1399-
else:
1400-
for _ in range(self.config.num_pipeline_repeats):
1401-
(loop_state, bsw), _ = run_one_repeat_scannable(self, (loop_state, bsw))
1402-
for _ in range(bubble_iterations):
1403-
(loop_state, bsw), _ = run_iteration_scannable(self, loop_state, bsw)
1401+
run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
1402+
pipeline_stage_fn=run_one_repeat_scannable, length=self.config.num_pipeline_repeats
1403+
)
1404+
run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(pipeline_stage_fn=run_bubbles_scannable, length=1)
1405+
(loop_state, w_curr), _ = run_repeats_scanned(self, (loop_state, bsw[0]))
1406+
(loop_state, _), _ = run_bubbles_scanned(self, (loop_state, w_curr))
14041407

14051408
final_output = self.realign_output_microbatches(loop_state["state_io"])
14061409
final_output = jnp.reshape(

src/maxtext/utils/pipeline_utils.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -286,23 +286,30 @@ def create_rematerialized_pipeline_stage(
286286
"""
287287

288288
def execute_pipeline_stage(model, loop_state_and_bsw):
289-
loop_state, bsw = loop_state_and_bsw
290-
# Retrieve the specific weights needed for this pipeline chunk
291-
bsw = model.weight_prefetching(pipeline_weights, physical_partition_spec, loop_state["loop_iteration"])
292-
293-
if model.config.scan_pipeline_iterations:
294-
scan_microbatches_fn = create_gradient_accumulation_scan(
295-
model=model,
296-
length=length,
297-
deterministic=deterministic,
298-
model_mode=model_mode,
299-
logical_partition_spec=logical_partition_spec,
300-
)
301-
loop_state, bsw = scan_microbatches_fn(loop_state, bsw, positions, segment_ids)
302-
else:
303-
for _ in range(length):
304-
(loop_state, bsw), _ = run_iteration_scannable(model, loop_state, bsw)
305-
return (loop_state, bsw), None
289+
loop_state, w_curr = loop_state_and_bsw
290+
# # Retrieve the specific weights needed for this pipeline chunk
291+
# bsw = model.both_weight_prefetching(pipeline_weights, physical_partition_spec, loop_state["loop_iteration"])
292+
w_next = jax.remat(
293+
model.one_weight_prefetching,
294+
static_argnums=(1,),
295+
policy=jax.checkpoint_policies.nothing_saveable,
296+
)(
297+
pipeline_weights,
298+
physical_partition_spec,
299+
loop_state["loop_iteration"],
300+
)
301+
bsw = (w_curr, w_next)
302+
scan_microbatches_fn = create_gradient_accumulation_scan(
303+
model=model,
304+
length=length,
305+
deterministic=deterministic,
306+
model_mode=model_mode,
307+
logical_partition_spec=logical_partition_spec,
308+
)
309+
loop_state, bsw = scan_microbatches_fn(loop_state, bsw, positions, segment_ids)
310+
w_curr, w_next = bsw
311+
del w_curr
312+
return (loop_state, w_next), None
306313

307314
return nn.remat(
308315
execute_pipeline_stage,

0 commit comments

Comments (0)