Skip to content

Commit eb931a1

Browse files
committed
add lax.cond and dtype control
1 parent 60dbacd commit eb931a1

3 files changed

Lines changed: 21 additions & 6 deletions

File tree

src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ logical_axis_rules: [
4141
['activation_attn_length', ['context', 'expert']],
4242
['activation_q_length', ['context', 'expert']],
4343
['activation_attn_embed', ['tensor']],
44+
['activation_norm_length', ['context']],
45+
['activation_norm_length_moe', ['context']],
4446
['activation_embed', ['tensor']],
4547
['activation_embed_moe', ['tensor']],
4648
['activation_mlp', ['tensor']],

src/maxtext/layers/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -970,8 +970,8 @@ def init_states(self, inputs):
970970
def _init_empty_bsw_buffers(variables):
971971
# BSW requires two buffers (current and next) for the sliding window
972972
return (
973-
jax.tree.map(lambda x: jnp.zeros_like(x[0]), variables),
974-
jax.tree.map(lambda x: jnp.zeros_like(x[0]), variables),
973+
jax.tree.map(lambda x: jnp.zeros_like(x[0], dtype=jnp.bfloat16), variables),
974+
jax.tree.map(lambda x: jnp.zeros_like(x[0], dtype=jnp.bfloat16), variables),
975975
)
976976

977977
if self.is_initializing():

src/maxtext/utils/pipeline_utils.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,20 @@
2222
import jax.numpy as jnp
2323

2424

25+
def to_bf16(tree):
    """Cast every floating-point leaf of a PyTree to bfloat16.

    Args:
        tree: Any PyTree. A leaf is converted only when it exposes a ``dtype``
            attribute and that dtype is a sub-dtype of ``jnp.floating``.

    Returns:
        A PyTree with the same structure as ``tree``. Floating-point leaves
        are cast to ``jnp.bfloat16``; every other leaf (integer/bool arrays,
        objects without a ``dtype``) is returned unchanged.
    """

    def _cast(leaf):
        # Leaves without a dtype (e.g. plain Python scalars) and non-floating
        # arrays pass through untouched so indices and masks keep their dtype.
        if hasattr(leaf, "dtype") and jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(jnp.bfloat16)
        return leaf

    return jax.tree.map(_cast, tree)
30+
31+
32+
def to_f32(tree):
    """Cast every floating-point leaf of a PyTree to float32.

    Args:
        tree: Any PyTree. A leaf is converted only when it exposes a ``dtype``
            attribute and that dtype is a sub-dtype of ``jnp.floating``.

    Returns:
        A PyTree with the same structure as ``tree``. Floating-point leaves
        are cast to ``jnp.float32``; every other leaf (integer/bool arrays,
        objects without a ``dtype``) is returned unchanged.
    """

    def _cast(leaf):
        # Leaves without a dtype (e.g. plain Python scalars) and non-floating
        # arrays pass through untouched so indices and masks keep their dtype.
        if hasattr(leaf, "dtype") and jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(jnp.float32)
        return leaf

    return jax.tree.map(_cast, tree)
37+
38+
2539
def get_mesh_axis_dim_indices(physical_partition_spec, axis_name="fsdp"):
2640
"""Finds the tensor dimension index sharded across a specific physical mesh axis.
2741
@@ -310,7 +324,7 @@ def execute_pipeline_stage_pure(loop_state, w_curr, pipeline_weights):
310324
def execute_pipeline_stage_pure_fwd(loop_state, w_curr, pipeline_weights):
311325
# Prefetch FSDP-sharded weights for the upcoming pipeline repeat
312326
w_next = model.weight_prefetching(
313-
pipeline_weights,
327+
to_bf16(pipeline_weights),
314328
physical_partition_spec,
315329
loop_state["loop_iteration"],
316330
)
@@ -332,9 +346,8 @@ def execute_pipeline_stage_pure_fwd(loop_state, w_curr, pipeline_weights):
332346
)
333347
# Execute the forward pass of the microbatches and generate its VJP.
334348
# The VJP captures necessary checkpoints to evaluate gradients later.
335-
(loop_state, bsw), scan_microbatches_vjp = jax.vjp(scan_microbatches_fn, loop_state, bsw, positions, segment_ids)
349+
(loop_state, _), scan_microbatches_vjp = jax.vjp(scan_microbatches_fn, loop_state, bsw, positions, segment_ids)
336350
# Discard the old weights (w_curr) and advance w_next to act as the current weights in the next iteration
337-
_, w_next = bsw
338351
return (loop_state, w_next), (scan_microbatches_vjp, weight_prefetching_t)
339352

340353
def execute_pipeline_stage_pure_bwd(residuals, g_outputs):
@@ -349,7 +362,7 @@ def execute_pipeline_stage_pure_bwd(residuals, g_outputs):
349362
# Apply the linear transpose of the weight prefetch to execute the reduce-scatter
350363
# This maps the gradients of the gathered weights back to the FSDP-sharded parameter space
351364
g_w_curr, g_w_next = g_bsw
352-
(g_pipeline_weights,) = weight_prefetching_t(g_w_next)
365+
(g_pipeline_weights,) = weight_prefetching_t(to_f32(g_w_next))
353366
# Return gradients corresponding to the three original inputs of execute_pipeline_stage_pure
354367
return g_loop_state, g_w_curr, g_pipeline_weights
355368

0 commit comments

Comments
 (0)