Skip to content

Commit ac07a32

Browse files
Merge pull request #3727 from AI-Hypercomputer:aireen/vocab_dtype
PiperOrigin-RevId: 904656787
2 parents 532c8b3 + a98df90 commit ac07a32

3 files changed

Lines changed: 22 additions & 3 deletions

File tree

src/maxtext/models/gemma4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def __init__(
299299
else:
300300
self.post_ffw_norm = None
301301

302-
self.layer_scalar = nnx.Param(jnp.ones((1,), dtype=config.dtype), sharding=(None,))
302+
self.layer_scalar = nnx.Param(jnp.ones((1,), dtype=config.weight_dtype), sharding=(None,))
303303

304304
if model_mode == MODEL_MODE_PREFILL:
305305
self.activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")

src/maxtext/utils/vocabulary_tiling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,10 @@ def _bwd_scan_body(grad_params_acc, chunk_data):
224224
_bwd_scan_body, initial_grad_params_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
225225
)
226226
grad_reshaped_hidden_states = _maybe_shard_with_name(grad_reshaped_hidden_states, reshaped_hidden_spec)
227-
# TODO (chengnuojin): we may want to convert grad_params to bf16 to save memory
228-
# grad_params = jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), gathered_params, grad_params)
229227
# Chain-rule to accumulate gradients
230228
grad_params = jax.tree_util.tree_map(lambda g: g * loss_cotangent, grad_params)
229+
# Cast cotangents back to each primal's dtype; custom_vjp requires dtype match.
230+
grad_params = jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), gathered_params, grad_params)
231231
# Give back sharding constraint
232232
grad_reshaped_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec)
233233
return (

tests/unit/train_compile_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -955,3 +955,22 @@ def test_qk_clip(self):
955955
"qk_clip_threshold=100",
956956
)
957957
)
958+
959+
@pytest.mark.cpu_only
960+
def test_vocab_tiling_bf16(self):
961+
"""test vocab_tiling when weight_dtype=bfloat16"""
962+
compiled_trainstep_file = "/tmp/test_vocab_tiling_bf16.pickle"
963+
train_compile_main(
964+
(
965+
"",
966+
get_test_config_path(),
967+
f"compiled_trainstep_file={compiled_trainstep_file}",
968+
"compile_topology=v5p-8",
969+
"compile_topology_num_slices=1",
970+
"base_num_decoder_layers=2",
971+
"per_device_batch_size=2",
972+
"max_target_length=1024",
973+
"num_vocab_tiling=4",
974+
"weight_dtype=bfloat16",
975+
)
976+
)

0 commit comments

Comments
 (0)