Skip to content

Commit 680f1ad

Browse files
committed
remove additional AG
1 parent cc1c7df commit 680f1ad

2 files changed

Lines changed: 49 additions & 34 deletions

File tree

src/maxtext/layers/pipeline.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def get_microbatch_and_repeat_ids(self, loop_iteration):
184184
non-circular"""
185185
# Stage 0 has processed one microbatch every loop_iter, but Stage 1 is 1 behind due to bubble, etc for other stages
186186
microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0)
187+
microbatches_processed = self._maybe_shard_with_name(microbatches_processed, NamedSharding(self.mesh, P("stage")))
187188
microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches
188189
repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches
189190
return microbatch_ids, repeat_ids
@@ -1006,8 +1007,12 @@ def gather_microbatch_inputs_vmap(self, xs, ids, ids_dim):
10061007

10071008
def _gather_one(x, i):
10081009
idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim))
1009-
replicated_sharding = NamedSharding(self.mesh, P())
1010-
return x.at[idx].get(out_sharding=replicated_sharding)
1010+
positions_sharding = (
1011+
create_sharding(self.mesh, (None, "layers", "activation_length"))
1012+
if self.config.shard_mode == ShardMode.EXPLICIT
1013+
else None
1014+
)
1015+
return x.at[idx].get(out_sharding=positions_sharding)
10111016

10121017
return jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids)
10131018

@@ -1229,7 +1234,7 @@ def _apply_sharding_hint(weight, pspec):
12291234
weight,
12301235
sharding_name,
12311236
shard_mode=self.config.shard_mode,
1232-
debug_sharding=self.config.shard_mode,
1237+
debug_sharding=self.config.debug_sharding,
12331238
extra_stack_level=0,
12341239
)
12351240

@@ -1239,7 +1244,7 @@ def _apply_sharding_hint(weight, pspec):
12391244
return _from_repeat_weights_to_bsw_shardmap(repeat_weights, physical_partition_spec, axes_to_gather=axes_to_gather)
12401245
return _from_repeat_weights_to_bsw_hint(repeat_weights)
12411246

1242-
def weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
1247+
def both_weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
12431248
"""Triggers asynchronous FSDP-like all-gathers for the current and next pipeline steps.
12441249
12451250
By gathering weights for `loop_iteration + 1` right now, the network communication
@@ -1250,7 +1255,16 @@ def weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
12501255
nxt_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1)
12511256
bsw_0 = self.from_repeat_weights_to_bsw(cur_repeat_weights, physical_partition_spec)
12521257
bsw_1 = self.from_repeat_weights_to_bsw(nxt_repeat_weights, physical_partition_spec)
1253-
return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw")
1258+
return bsw_0, bsw_1
1259+
1260+
def one_weight_prefetching(self, weights, physical_partition_spec, loop_iteration):
1261+
"""Triggers asynchronous FSDP-like all-gathers for the next pipeline steps.
1262+
1263+
By gathering weights for `loop_iteration + 1` right now, the network communication
1264+
can overlap with the compute happening in `loop_iteration`.
1265+
"""
1266+
repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1)
1267+
return self.from_repeat_weights_to_bsw(repeat_weights, physical_partition_spec)
12541268

12551269
def run_one_iteration(self, loop_state, bsw, positions, segment_ids, deterministic, model_mode, logical_partition_spec):
12561270
"""Executes the forward/backward logic for a single microbatch inside the pipeline.
@@ -1389,18 +1403,12 @@ def run_iteration_scannable(model, loop_state, bsw):
13891403
run_one_repeat_scannable = base_scannable(length=self.config.num_pipeline_microbatches)
13901404
run_bubbles_scannable = base_scannable(length=bubble_iterations)
13911405

1392-
if self.config.scan_pipeline_repeats:
1393-
run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
1394-
pipeline_stage_fn=run_one_repeat_scannable, length=self.config.num_pipeline_repeats
1395-
)
1396-
run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(pipeline_stage_fn=run_bubbles_scannable, length=1)
1397-
(loop_state, bsw), _ = run_repeats_scanned(self, (loop_state, bsw))
1398-
(loop_state, bsw), _ = run_bubbles_scanned(self, (loop_state, bsw))
1399-
else:
1400-
for _ in range(self.config.num_pipeline_repeats):
1401-
(loop_state, bsw), _ = run_one_repeat_scannable(self, (loop_state, bsw))
1402-
for _ in range(bubble_iterations):
1403-
(loop_state, bsw), _ = run_iteration_scannable(self, loop_state, bsw)
1406+
run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan(
1407+
pipeline_stage_fn=run_one_repeat_scannable, length=self.config.num_pipeline_repeats
1408+
)
1409+
run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(pipeline_stage_fn=run_bubbles_scannable, length=1)
1410+
(loop_state, w_curr), _ = run_repeats_scanned(self, (loop_state, bsw[0]))
1411+
(loop_state, _), _ = run_bubbles_scanned(self, (loop_state, w_curr))
14041412

14051413
final_output = self.realign_output_microbatches(loop_state["state_io"])
14061414
final_output = jnp.reshape(

src/maxtext/utils/pipeline_utils.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -286,23 +286,30 @@ def create_rematerialized_pipeline_stage(
286286
"""
287287

288288
def execute_pipeline_stage(model, loop_state_and_bsw):
289-
loop_state, bsw = loop_state_and_bsw
290-
# Retrieve the specific weights needed for this pipeline chunk
291-
bsw = model.weight_prefetching(pipeline_weights, physical_partition_spec, loop_state["loop_iteration"])
292-
293-
if model.config.scan_pipeline_iterations:
294-
scan_microbatches_fn = create_gradient_accumulation_scan(
295-
model=model,
296-
length=length,
297-
deterministic=deterministic,
298-
model_mode=model_mode,
299-
logical_partition_spec=logical_partition_spec,
300-
)
301-
loop_state, bsw = scan_microbatches_fn(loop_state, bsw, positions, segment_ids)
302-
else:
303-
for _ in range(length):
304-
(loop_state, bsw), _ = run_iteration_scannable(model, loop_state, bsw)
305-
return (loop_state, bsw), None
289+
loop_state, w_curr = loop_state_and_bsw
290+
# # Retrieve the specific weights needed for this pipeline chunk
291+
# bsw = model.both_weight_prefetching(pipeline_weights, physical_partition_spec, loop_state["loop_iteration"])
292+
w_next = jax.remat(
293+
model.one_weight_prefetching,
294+
static_argnums=(1,),
295+
policy=jax.checkpoint_policies.nothing_saveable,
296+
)(
297+
pipeline_weights,
298+
physical_partition_spec,
299+
loop_state["loop_iteration"],
300+
)
301+
bsw = (w_curr, w_next)
302+
scan_microbatches_fn = create_gradient_accumulation_scan(
303+
model=model,
304+
length=length,
305+
deterministic=deterministic,
306+
model_mode=model_mode,
307+
logical_partition_spec=logical_partition_spec,
308+
)
309+
loop_state, bsw = scan_microbatches_fn(loop_state, bsw, positions, segment_ids)
310+
w_curr, w_next = bsw
311+
del w_curr
312+
return (loop_state, w_next), None
306313

307314
return nn.remat(
308315
execute_pipeline_stage,

0 commit comments

Comments (0)