diff --git a/src/maxtext/models/gemma4.py b/src/maxtext/models/gemma4.py
index 1803ec705c..b4af07122e 100644
--- a/src/maxtext/models/gemma4.py
+++ b/src/maxtext/models/gemma4.py
@@ -299,7 +299,7 @@ def __init__(
     else:
       self.post_ffw_norm = None
 
-    self.layer_scalar = nnx.Param(jnp.ones((1,), dtype=config.dtype), sharding=(None,))
+    self.layer_scalar = nnx.Param(jnp.ones((1,), dtype=config.weight_dtype), sharding=(None,))
 
     if model_mode == MODEL_MODE_PREFILL:
       self.activation_axis_names = ("activation_batch", "prefill_activation_norm_length", "activation_embed")
diff --git a/src/maxtext/utils/vocabulary_tiling.py b/src/maxtext/utils/vocabulary_tiling.py
index 06184cff87..e7b155416c 100644
--- a/src/maxtext/utils/vocabulary_tiling.py
+++ b/src/maxtext/utils/vocabulary_tiling.py
@@ -224,10 +224,10 @@ def _bwd_scan_body(grad_params_acc, chunk_data):
       _bwd_scan_body, initial_grad_params_acc, (reshaped_hidden_states, reshaped_labels, reshaped_segmentation)
   )
   grad_reshaped_hidden_states = _maybe_shard_with_name(grad_reshaped_hidden_states, reshaped_hidden_spec)
-  # TODO (chengnuojin): we may want to convert grad_params to bf16 to save memory
-  # grad_params = jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), gathered_params, grad_params)
   # Chain-rule to accumulate gradients
   grad_params = jax.tree_util.tree_map(lambda g: g * loss_cotangent, grad_params)
+  # Cast cotangents back to each primal's dtype; custom_vjp requires dtype match.
+  grad_params = jax.tree_util.tree_map(lambda x, y: y.astype(x.dtype), gathered_params, grad_params)
   # Give back sharding constraint
   grad_reshaped_hidden_states = _reshape(grad_reshaped_hidden_states, (batch_size, seq_len, emb_dim), hidden_spec)
   return (
diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py
index 4230c46174..75b5f540b4 100644
--- a/tests/unit/train_compile_test.py
+++ b/tests/unit/train_compile_test.py
@@ -974,3 +974,22 @@ def test_qk_clip(self):
         "qk_clip_threshold=100",
       )
     )
+
+  @pytest.mark.cpu_only
+  def test_vocab_tiling_bf16(self):
+    """test vocab_tiling when weight_dtype=bfloat16"""
+    compiled_trainstep_file = "/tmp/test_vocab_tiling_bf16.pickle"
+    train_compile_main(
+        (
+            "",
+            get_test_config_path(),
+            f"compiled_trainstep_file={compiled_trainstep_file}",
+            "compile_topology=v5p-8",
+            "compile_topology_num_slices=1",
+            "base_num_decoder_layers=2",
+            "per_device_batch_size=2",
+            "max_target_length=1024",
+            "num_vocab_tiling=4",
+            "weight_dtype=bfloat16",
+        )
+    )