Skip to content

Commit 24a3c79

Browse files
Kevin Wang authored and Google-ML-Automation committed
Offload all checkpoints to host and remove unused mesh axes.
PiperOrigin-RevId: 903412342
1 parent d8f10cf commit 24a3c79

2 files changed

Lines changed: 8 additions & 11 deletions

File tree

src/maxtext/configs/models/deepseek3-671b-batchsplit.yml

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -59,30 +59,27 @@ use_batch_split_schedule: True
5959
shard_mode: "explicit"
6060
remove_size_one_mesh_axis_from_type: False
6161
override_logical_axis_rules: True
62-
mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'expert', 'context']
63-
data_sharding: [['data', 'stage', 'fsdp', 'expert', 'context']]
62+
mesh_axes: ['data', 'fsdp', 'expert', 'context']
63+
data_sharding: [['data', 'fsdp', 'expert', 'context']]
6464
logical_axis_rules: [
6565
['activation_batch', ['data', 'fsdp', 'expert', 'context']],
6666
['activation_batch_moe', ['data', 'fsdp', 'expert', 'context']],
67-
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert', 'context']],
67+
['activation_embed_and_logits_batch', ['data', 'fsdp', 'expert', 'context']],
6868
['activation_kv_batch', ['data', 'fsdp', 'expert', 'context']],
6969
['activation_norm_length', []],
7070
['activation_norm_length_moe', []],
7171
['activation_heads', []],
72-
['activation_stage', 'stage'],
7372
['embed', ['fsdp']],
7473
['embed_moe', ['fsdp']],
7574
['embed_no_exp', ['fsdp']],
7675
['embed_no_exp_moe', ['fsdp']],
7776
['q_lora', ['fsdp']],
7877
['kv_lora', ['fsdp']],
79-
['layers', 'stage'],
8078
['q_lora_up_proj', []],
8179
['kv_lora_up_proj', []],
8280
['q_heads', []],
8381
['kv_heads', []],
8482
['heads', []],
8583
['mlp', []],
8684
['expert_only', ['expert']],
87-
['diloco', 'diloco'],
8885
]

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -714,7 +714,7 @@ def process_layer_scannable(carry, layer_idx, group_id):
714714
pairwise_swap_and_negate_mask=yarn_mask,
715715
)
716716
# Offload to host memory.
717-
for residual_name in ("mlpwi_0", "mlpwi_1"):
717+
for residual_name in ("mlpwi_0", "mlpwi_1", "attn_out", "layer_inputs"):
718718
r = res.pop(residual_name)
719719
r = jax.tree.map(lambda x: jax.device_put(x, jax.typeof(x).sharding.with_memory_kind("pinned_host")), r)
720720
res[residual_name] = r
@@ -736,7 +736,7 @@ def process_layer_scannable(carry, layer_idx, group_id):
736736
pairwise_swap_and_negate_mask=yarn_mask,
737737
)
738738
# Offload first layer residuals to host memory.
739-
for residual_name in ("mlpwi_0", "mlpwi_1"):
739+
for residual_name in ("mlpwi_0", "mlpwi_1", "attn_out", "layer_inputs"):
740740
r = first_res.pop(residual_name)
741741
r = jax.tree.map(lambda x: jax.device_put(x, jax.typeof(x).sharding.with_memory_kind("pinned_host")), r)
742742
first_res[residual_name] = r
@@ -829,7 +829,7 @@ def process_layer_bwd_scannable(carry, res_and_layer_idx, group_id):
829829
next_next_ws_grad = all_reduce_ws_grad_dcn(next_next_ws_grad, mesh)
830830
all_layer_ws_grad = insert_layer_ws_grad(all_layer_ws_grad, next_next_ws_grad, layer_idx + 2, cfg.param_scan_axis)
831831
# Get residuals from host.
832-
for residual_name in ("mlpwi_0", "mlpwi_1"):
832+
for residual_name in ("mlpwi_0", "mlpwi_1", "attn_out", "layer_inputs"):
833833
r = res.pop(residual_name)
834834
r = jax.tree.map(lambda x: jax.device_put(x, jax.typeof(x).sharding.with_memory_kind("device")), r)
835835
res[residual_name] = r
@@ -890,7 +890,7 @@ def process_layer_bwd_scannable(carry, res_and_layer_idx, group_id):
890890
prev_prev_ws = gather_weights(extract_layer_weights(all_weights, num_layers - 3, cfg.param_scan_axis), mesh)
891891
ws_grad = reduce_scatter_ws_grad(ws_grad, mesh)
892892
# Get residuals from host.
893-
for residual_name in ("mlpwi_0", "mlpwi_1"):
893+
for residual_name in ("mlpwi_0", "mlpwi_1", "attn_out", "layer_inputs"):
894894
r = last_last_res.pop(residual_name)
895895
r = jax.tree.map(lambda x: jax.device_put(x, jax.typeof(x).sharding.with_memory_kind("device")), r)
896896
last_last_res[residual_name] = r
@@ -931,7 +931,7 @@ def process_layer_bwd_scannable(carry, res_and_layer_idx, group_id):
931931
third_ws_grad = all_reduce_ws_grad_dcn(third_ws_grad, mesh)
932932
all_layer_ws_grad = insert_layer_ws_grad(all_layer_ws_grad, third_ws_grad, 2, cfg.param_scan_axis)
933933
# Get residuals from host.
934-
for residual_name in ("mlpwi_0", "mlpwi_1"):
934+
for residual_name in ("mlpwi_0", "mlpwi_1", "attn_out", "layer_inputs"):
935935
r = first_res.pop(residual_name)
936936
r = jax.tree.map(lambda x: jax.device_put(x, jax.typeof(x).sharding.with_memory_kind("device")), r)
937937
first_res[residual_name] = r

0 commit comments

Comments
 (0)