
Commit a98df90

fix dtype mismatch when using vocab_tiling

1 parent 1907615 · commit a98df90
3 files changed: 22 additions & 3 deletions

File tree

src/maxtext/models/gemma4.py
src/maxtext/utils/vocabulary_tiling.py
tests/unit/train_compile_test.py
src/maxtext/models/gemma4.py
Lines changed: 1 addition & 1 deletion

@@ -299,7 +299,7 @@ def __init__(
     else:
       self.post_ffw_norm = None

-    self.layer_scalar = nnx.Param(jnp.ones((1,), dtype=config.dtype), sharding=(None,))
+    self.layer_scalar = nnx.Param(jnp.ones((1,), dtype=config.weight_dtype), sharding=(None,))

     if model_mode == MODEL_MODE_PREFILL:
       self.activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
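
Context for the change: in MaxText configs, `dtype` is the activation/compute dtype (often bfloat16) while `weight_dtype` is the parameter storage dtype (often float32). Initializing `layer_scalar` from the activation dtype let this one parameter's dtype disagree with the rest of the weights, which surfaces as a cotangent dtype mismatch once vocabulary tiling's custom VJP hands back per-parameter gradients. A minimal sketch of the distinction follows; `Cfg` is a hypothetical stand-in for the real config object, not MaxText code.

# Minimal sketch, assuming the usual MaxText convention: `dtype` is the
# activation/compute dtype, `weight_dtype` the parameter storage dtype.
import jax.numpy as jnp
from flax import nnx

class Cfg:
  dtype = jnp.bfloat16        # activations (compute)
  weight_dtype = jnp.float32  # parameters (storage)

config = Cfg()

# Before the fix: the scalar was created in the activation dtype, so a
# float32-weight model still got a bfloat16 parameter here.
before = nnx.Param(jnp.ones((1,), dtype=config.dtype))
# After the fix: the scalar follows the other weights.
after = nnx.Param(jnp.ones((1,), dtype=config.weight_dtype))
print(before.value.dtype, after.value.dtype)  # bfloat16 float32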

src/maxtext/utils/vocabulary_tiling.py
Lines changed: 2 additions & 2 deletions

@@ -224,10 +224,10 @@ def _bwd_scan_body(grad_params_acc, chunk_data):
       _bwd_scan_body, initial_grad_params_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
   )
   grad_reshaped_hidden_states = _maybe_shard_with_name(grad_reshaped_hidden_states, reshaped_hidden_spec)
-  # TODO (chengnuojin): we may want to convert grad_params to bf16 to save memory
-  # grad_params = jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), gathered_params, grad_params)
   # Chain-rule to accumulate gradients
   grad_params = jax.tree_util.tree_map(lambda g: g * loss_cotangent, grad_params)
+  # Cast cotangents back to each primal's dtype; custom_vjp requires dtype match.
+  grad_params = jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), gathered_params, grad_params)
   # Give back sharding constraint
   grad_reshaped_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec)
   return (
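
The restored cast is the substance of the fix: `jax.custom_vjp` requires the cotangents a backward rule returns to match the dtypes of the corresponding primal inputs, and with `weight_dtype=bfloat16` the gradients accumulated in the scan come back as float32. A self-contained sketch of the same requirement is below; the function names are illustrative, not the MaxText helpers.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def scale(w, x):
  return (w * x).sum()

def scale_fwd(w, x):
  return scale(w, x), (w, x)

def scale_bwd(res, g):
  w, x = res
  grad_w = g * x  # promotes to float32 when x is float32
  grad_x = g * w
  # Cast each cotangent back to its primal's dtype, mirroring the fix above;
  # returning a float32 cotangent for a bfloat16 primal is rejected by JAX.
  return grad_w.astype(w.dtype), grad_x.astype(x.dtype)

scale.defvjp(scale_fwd, scale_bwd)

w = jnp.ones((4,), dtype=jnp.bfloat16)  # like a bfloat16 weight
x = jnp.ones((4,), dtype=jnp.float32)
print(jax.grad(scale)(w, x).dtype)  # bfloat16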

tests/unit/train_compile_test.py
Lines changed: 19 additions & 0 deletions

@@ -974,3 +974,22 @@ def test_qk_clip(self):
             "qk_clip_threshold=100",
         )
     )
+
+  @pytest.mark.cpu_only
+  def test_vocab_tiling_bf16(self):
+    """test vocab_tiling when weight_dtype=bfloat16"""
+    compiled_trainstep_file = "/tmp/test_vocab_tiling_bf16.pickle"
+    train_compile_main(
+        (
+            "",
+            get_test_config_path(),
+            f"compiled_trainstep_file={compiled_trainstep_file}",
+            "compile_topology=v5p-8",
+            "compile_topology_num_slices=1",
+            "base_num_decoder_layers=2",
+            "per_device_batch_size=2",
+            "max_target_length=1024",
+            "num_vocab_tiling=4",
+            "weight_dtype=bfloat16",
+        )
+    )
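
A note on the regression test: `train_compile_main` ahead-of-time compiles a train step for the target topology (here `v5p-8`), so the `cpu_only` marker should suffice to exercise the `num_vocab_tiling=4` plus `weight_dtype=bfloat16` combination that previously hit the dtype mismatch, without needing TPU hardware.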
