diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 5e59a0f4be..aae5fc318f 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -1158,8 +1158,8 @@ position_id_per_seconds: 25 subslice_shape: "" # NNX -enable_nnx: False -pure_nnx_decoder: False +enable_nnx: True +pure_nnx_decoder: True pure_nnx: False ################################## Qwen3-Next Specific Configs ################################## diff --git a/src/maxtext/layers/initializers.py b/src/maxtext/layers/initializers.py index 20baf9a633..e7ea2094db 100644 --- a/src/maxtext/layers/initializers.py +++ b/src/maxtext/layers/initializers.py @@ -94,6 +94,16 @@ def variable_to_logically_partitioned(variable: nnx.VariableState): out_sharding = metadata["sharding"] if out_sharding is not None: + if nnx.PARTITION_NAME in metadata: + partition_name = metadata[nnx.PARTITION_NAME] + scan_axis = metadata.get("param_scan_axis", 0) if variable.type == nnx.Param else 0 + + sharding_list = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding) + if partition_name not in sharding_list: + sharding_list.insert(scan_axis, partition_name) + + out_sharding = tuple(sharding_list) + return nn.LogicallyPartitioned( # type: ignore[wrong-keyword-args] variable.value, out_sharding, # type: ignore[arg-type] diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 3c8a601201..e5839214fd 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -46,9 +46,11 @@ from maxtext.models import ( deepseek, deepseek_batchsplit, + deepseek_batchsplit_fp8, gemma, gemma2, gemma3, + gemma4, gpt3, gpt_oss, llama2, @@ -70,7 +72,7 @@ class NNXDecoderLayer(nnx.Module): """ - Transformer decoder layer converted to NNX. 
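A minimal standalone sketch of the tuple surgery performed by the `variable_to_logically_partitioned` hunk in initializers.py above: splice the scan/partition axis name into a logical sharding tuple at `param_scan_axis`, skipping variables that already carry it. The `insert_scan_axis` helper and its defaults are illustrative assumptions, not part of the patch.

```python
# Hypothetical helper mirroring the initializers.py hunk above; "layers" and
# scan_axis=0 are assumed defaults for illustration only.
def insert_scan_axis(out_sharding, partition_name="layers", scan_axis=0):
  sharding_list = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding)
  if partition_name not in sharding_list:
    sharding_list.insert(scan_axis, partition_name)
  return tuple(sharding_list)

assert insert_scan_axis(("embed", "mlp")) == ("layers", "embed", "mlp")
assert insert_scan_axis(("embed", "mlp"), scan_axis=1) == ("embed", "layers", "mlp")
assert insert_scan_axis(("layers", "embed")) == ("layers", "embed")  # already stacked: unchanged
```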
+ Transformer decoder layer converted to NNX. """ def __init__( @@ -169,7 +171,7 @@ def __call__( if self.model_mode == MODEL_MODE_PREFILL: logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") else: - logical_axis_names = ("activation_batch", "activation_length", "activation_embed") + logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed") inputs = _maybe_shard_with_logical(inputs, logical_axis_names) inputs = checkpoint_name(inputs, "decoder_layer_input") @@ -258,14 +260,6 @@ def __init__( decoder_block_classes = self.get_decoder_layers() - self.decoder_norm = self.get_norm_layer(num_features=config.emb_dim, rngs=rngs)( - dtype=config.dtype, - weight_dtype=config.weight_dtype, - epsilon=config.normalization_layer_epsilon, - kernel_axes=("norm",), - parameter_memory_host_offload=config.parameter_memory_host_offload, - ) - if config.trainable_position_size > 0: self.position_embedder = Embed( num_embeddings=config.trainable_position_size, @@ -278,9 +272,15 @@ def __init__( ) self.dropout = linears.Dropout(rate=config.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) - self.positional_embedding = PositionalEmbedding(embedding_dims=config.base_emb_dim) + self.decoder_norm = self.get_norm_layer(num_features=config.emb_dim, rngs=rngs)( + dtype=config.dtype, + weight_dtype=config.weight_dtype, + epsilon=config.normalization_layer_epsilon, + kernel_axes=("norm",), + parameter_memory_host_offload=config.parameter_memory_host_offload, + ) if not config.logits_via_embedding: self.logits_dense = linears.DenseGeneral( in_features_shape=config.emb_dim, @@ -297,18 +297,61 @@ def __init__( self.scanned_layers = None self.is_deepseek = self.config.decoder_block == DecoderBlockType.DEEPSEEK self.is_gemma3 = self.config.decoder_block == DecoderBlockType.GEMMA3 + self.is_gemma4 = self.config.decoder_block == DecoderBlockType.GEMMA4 if self.config.scan_layers: if self.is_deepseek: assert len(decoder_block_classes) == 2 dense_cls, moe_cls = decoder_block_classes - num_dense = config.first_num_dense_layers - self.dense_layers = self._create_scanned_layers(dense_cls, length=num_dense, rngs=rngs) - - num_moe = config.num_decoder_layers - config.first_num_dense_layers - - self.moe_layer = self._create_scanned_layers(moe_cls, length=num_moe, rngs=rngs) + if config.engram_layers: + # 1. Create Dense Chunks (Direct setattr, NO nnx.Dict) + current_idx = 0 + while current_idx < config.first_num_dense_layers: + if current_idx in config.engram_layers: + layer_name = f"dense_layers_engram_{current_idx}" + setattr(self, layer_name, self._create_single_layer(dense_cls, rngs, layer_idx=current_idx)) + current_idx += 1 + else: + next_boundary = self._find_next_boundary(current_idx, config.first_num_dense_layers, config.engram_layers) + chunk_name = f"dense_layers_{current_idx}_{next_boundary - 1}" + setattr( + self, + chunk_name, + self._create_scanned_layers( + dense_cls, length=(next_boundary - current_idx), metadata_axis_name=chunk_name, rngs=rngs + ), + ) + current_idx = next_boundary + + # 2. 
Create MoE Chunks (Direct setattr, NO nnx.Dict) + current_idx = config.first_num_dense_layers + while current_idx < config.num_decoder_layers: + if current_idx in config.engram_layers: + layer_name = f"moe_layers_engram_{current_idx}" + setattr(self, layer_name, self._create_single_layer(moe_cls, rngs, layer_idx=current_idx)) + current_idx += 1 + else: + next_boundary = self._find_next_boundary(current_idx, config.num_decoder_layers, config.engram_layers) + chunk_name = f"moe_layers_{current_idx}_{next_boundary - 1}" + setattr( + self, + chunk_name, + self._create_scanned_layers( + moe_cls, length=(next_boundary - current_idx), metadata_axis_name=chunk_name, rngs=rngs + ), + ) + current_idx = next_boundary + else: + # Standard DeepSeek logic when Engrams are disabled + num_dense = config.first_num_dense_layers + self.dense_layers = self._create_scanned_layers( + dense_cls, length=num_dense, metadata_axis_name="dense_layers", rngs=rngs + ) + num_moe = config.num_decoder_layers - config.first_num_dense_layers + self.moe_layers = self._create_scanned_layers( + moe_cls, length=num_moe, metadata_axis_name="moe_layers", rngs=rngs + ) elif self.is_gemma3: attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) scan_length = config.num_decoder_layers // attention_pattern_length @@ -320,10 +363,29 @@ def __init__( RemattedGemma3Block = gemma3.Gemma3ScannableBlock if scan_length > 0: - self.layers = self._create_scanned_layers(RemattedGemma3Block, length=scan_length, rngs=rngs, **layer_kwargs) + self.layers = self._create_scanned_layers( + RemattedGemma3Block, length=scan_length, metadata_axis_name="layers", rngs=rngs, **layer_kwargs + ) self.layers_remainder = RemattedGemma3Block( config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs ) # pytype: disable=wrong-keyword-args + elif self.is_gemma4: + attention_pattern_length = len(gemma4.GEMMA4_ATTENTION_PATTERN) + scan_length = config.num_decoder_layers // attention_pattern_length + num_remaining_layers = config.num_decoder_layers % attention_pattern_length + layer_kwargs = {"num_of_layers": attention_pattern_length} + + rem_layer_kwargs = {"num_of_layers": num_remaining_layers} + + RemattedGemma4Block = gemma4.Gemma4ScannableBlock + + if scan_length > 0: + self.layers = self._create_scanned_layers( + RemattedGemma4Block, length=scan_length, metadata_axis_name="layers", rngs=rngs, **layer_kwargs + ) + self.layers_remainder = RemattedGemma4Block( + config=self.config, mesh=mesh, quant=self.quant, model_mode=self.model_mode, **rem_layer_kwargs, rngs=rngs + ) else: layer_cls = decoder_block_classes[0] num_layers = int(config.num_decoder_layers / config.inhomogeneous_layer_cycle_interval) @@ -334,7 +396,13 @@ def __init__( "interleave_moe_layer_step": self.config.interleave_moe_layer_step, } - self.layers = self._create_scanned_layers(layer_cls, length=num_layers, rngs=rngs, **layer_kwargs) + if num_layers > 0: + self.layers = self._create_scanned_layers( + layer_cls, length=num_layers, metadata_axis_name="layers", rngs=rngs, **layer_kwargs + ) + else: + self.layers = nnx.List([]) + else: self.layers = nnx.List([]) @@ -351,6 +419,8 @@ def __init__( layer_kwargs = {} if config.decoder_block == DecoderBlockType.GEMMA3: layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)} + elif config.decoder_block == DecoderBlockType.GEMMA4: + layer_kwargs = {"attention_type": gemma4.get_attention_type(layer_id=lyr)} elif config.decoder_block == DecoderBlockType.LLAMA4: layer_kwargs = { 
"is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval), @@ -383,34 +453,84 @@ def _create_single_layer(self, decoder_layer_class, rngs, **kwargs): ) return nnx_wrappers.ToNNX(layer_linen, rngs=rngs) - def _create_scanned_layers(self, decoder_layer_class, length: int, rngs: nnx.Rngs, **layer_kwargs): - """Creates a VMapped stack of layers, forcing parameter init for Compact modules.""" + def _create_scanned_layers( + self, decoder_layer_class, length: int, metadata_axis_name: str, rngs: nnx.Rngs, **layer_kwargs + ): + """Creates a scanned stack of layers using jax.lax.scan for memory-efficient initialization.""" + if length == 0: + return None + scan_axis = self.config.param_scan_axis + + # Fork rngs to get per-layer RNG states for scanning + try: + forked_rngs = rngs.fork(split=length) + except: # pylint: disable=bare-except + pass + + rngs_graphdef, rngs_state = nnx.split(forked_rngs) + + first_rng_state = jax.tree.map(lambda x: x[0], rngs_state) + ref_rngs = nnx.merge(rngs_graphdef, first_rng_state) + ref_layer = decoder_layer_class( + config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=ref_rngs, **layer_kwargs + ) + layer_graphdef, _, _ = nnx.split(ref_layer, nnx.Param, ...) + del ref_layer - def create_layer_fn(rng): + def scan_body(carry, rng_state_slice): + layer_rngs = nnx.merge(rngs_graphdef, rng_state_slice) layer = decoder_layer_class( - config=self.config, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, rngs=rng, **layer_kwargs + config=self.config, + mesh=self.mesh, + quant=self.quant, + model_mode=self.model_mode, + rngs=layer_rngs, + **layer_kwargs, ) + _, params, rest = nnx.split(layer, nnx.Param, ...) + return carry, (params, rest) - return layer + _, (stacked_params, stacked_rest) = jax.lax.scan(scan_body, None, rngs_state) - # Workaround for Deepseek MTP test failure. - # TODO: Handle this properly. 
- try: - forked_rngs = rngs.fork(split=length) + if scan_axis != 0: + stacked_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), stacked_params) - except: # pylint: disable=bare-except - pass + def _add_scan_metadata(state, axis): + def _update_leaf(leaf): + if hasattr(leaf, "replace") and hasattr(leaf, "value"): + replace_kwargs = {} + if hasattr(leaf, "get_metadata"): + replace_kwargs.update(leaf.get_metadata()) + + replace_kwargs[nnx.PARTITION_NAME] = metadata_axis_name + replace_kwargs["param_scan_axis"] = axis + + for key in ["sharding", "out_sharding", "kernel_axes", "sharding_names"]: + val = getattr(leaf, key, None) + if val is None and key in replace_kwargs: + val = replace_kwargs[key] - out_axes = nnx.StateAxes({nnx.Param: self.config.param_scan_axis, ...: 0}) - layers_vmapped = nnx.vmap( - create_layer_fn, - in_axes=0, - out_axes=out_axes, - axis_name="layers", - transform_metadata={nnx.PARTITION_NAME: "layers"}, - )(forked_rngs) + if val is not None: + if isinstance(val, str): + val = (val,) + if isinstance(val, tuple): + l = list(val) + # Safely insert the scan axis into the logical axes string + if metadata_axis_name not in l: + insert_idx = min(axis, len(l)) + l.insert(insert_idx, metadata_axis_name) + replace_kwargs[key] = tuple(l) - return layers_vmapped + return leaf.replace(**replace_kwargs) + return leaf + + # We must use a custom is_leaf to catch the VariableState instances + return jax.tree.map(_update_leaf, state, is_leaf=lambda x: hasattr(x, "replace") and hasattr(x, "value")) + + stacked_params = _add_scan_metadata(stacked_params, scan_axis) + stacked_rest = _add_scan_metadata(stacked_rest, 0) + + return nnx.merge(layer_graphdef, stacked_params, stacked_rest) def _apply_layer_with_remat(self, layer: nnx.Module, y: jax.Array, policy: Any, prevent_cse: bool, **kwargs): """Helper to cleanly apply jax.checkpoint to a single unscanned layer or block.""" @@ -430,56 +550,52 @@ def pure_layer_fn(state_in, y_in): def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs): """Runs the layer stack using nnx.scan.""" + if length == 0: + return x_in, layers policy = self.get_remat_policy() prevent_cse = maxtext_utils.should_prevent_cse_in_remat(self.config) - graphdef, params, state = nnx.split( - layers, nnx.Param, ... - ) # state: the mutable state we carry (KV cache, RNGs, etc.) + graphdef, params, state = nnx.split(layers, nnx.Param, ...) 
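The new `_create_scanned_layers` above initializes one layer per `jax.lax.scan` step, so only a single layer's parameters are live during initialization while `lax.scan` stacks each step's output along axis 0. A toy sketch of that init-by-scan pattern; `init_one_layer` is a stand-in for the real decoder-layer constructor, not part of the patch.

```python
import jax
import jax.numpy as jnp

def init_one_layer(key):
  # stand-in for decoder_layer_class(...): returns one layer's params
  return {"kernel": jax.random.normal(key, (4, 4))}

def scan_body(carry, key):
  return carry, init_one_layer(key)  # per-step outputs get stacked by lax.scan

keys = jax.random.split(jax.random.PRNGKey(0), 8)  # one RNG per layer
_, stacked = jax.lax.scan(scan_body, None, keys)
print(stacked["kernel"].shape)  # (8, 4, 4): layer axis stacked at 0

# When param_scan_axis != 0, the diff moves the stacked axis afterwards:
moved = jax.tree.map(lambda x: jnp.moveaxis(x, 0, 1), stacked)
print(moved["kernel"].shape)  # (4, 8, 4)
```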
scan_axis = self.config.param_scan_axis if scan_axis != 0: - # Move scan_axis to 0 so scan can iterate over it params = jax.tree.map(lambda x: jnp.moveaxis(x, scan_axis, 0), params) layer_cls = layers.__class__ sig = inspect.signature(layer_cls.__call__) valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters} - layer_cls = layers.__class__ # Access the underlying class - sig = inspect.signature(layer_cls.__call__) - # Filter kwargs to only include keys that exist in the layer's signature - valid_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters or "kwargs" in sig.parameters} + def _extract_matching_state(template, full): + if isinstance(template, nnx.State): + return nnx.State({k: _extract_matching_state(v, full[k]) for k, v in template.items()}) + elif isinstance(template, dict): + return {k: _extract_matching_state(v, full[k]) for k, v in template.items()} + return full def layer_fn(carry, scanned_vars): - # Unpack the sliced variables for THIS layer current_params, current_state = scanned_vars if self.config.parameter_memory_host_offload: current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params) - # Merge using the SLICED state layer = nnx.merge(graphdef, current_params, current_state) - - # Run the layer (Filter kwargs if using the solution from previous turn) layer_out = layer(carry, *args, **valid_kwargs) - new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out - # Extract the updated state to return it - # _, new_current_state = nnx.split(layer, nnx.Param, ...) - new_current_state = nnx.state(layer) + new_full_state = nnx.state(layer) + new_current_state = _extract_matching_state(current_state, new_full_state) + return new_carry, new_current_state layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse) - final_carry, scanned_state = jax.lax.scan(layer_fn, x_in, (params, state)) + final_carry, scanned_other = jax.lax.scan(layer_fn, x_in, (params, state)) if scan_axis != 0: - scanned_params, scanned_other = scanned_state.split(nnx.Param, ...) 
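For the forward pass, `_apply_layers_sequentially` runs the stacked parameters back through `jax.lax.scan`, which slices axis 0 of the parameter pytree at every step, so a single compiled layer body applies all layers in order. A self-contained sketch of that pattern; the `tanh` "layer" is a stand-in for the merged NNX decoder layer.

```python
import jax
import jax.numpy as jnp

def layer_fn(y, layer_params):
  # one decoder layer: consume this layer's slice of the stacked params
  return jnp.tanh(y @ layer_params["kernel"]), None  # carry activations only

num_layers, width = 8, 4
stacked_params = {"kernel": jnp.stack([jnp.eye(width)] * num_layers)}
y0 = jnp.ones((2, width))
y_final, _ = jax.lax.scan(layer_fn, y0, stacked_params)
print(y_final.shape)  # (2, 4): input shape preserved after all 8 layers
```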
- scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params) - scanned_state = nnx.State.merge(scanned_params, scanned_other) + params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), params) - return final_carry, nnx.merge(graphdef, scanned_state) + scanned_state = nnx.State.merge(params, scanned_other) + nnx.update(layers, scanned_state) + return final_carry, layers def get_decoder_layers(self): """Retrieves decoder layer classes based on config using a dictionary lookup.""" @@ -489,8 +605,6 @@ def get_scannable(normal_cls, scannable_cls): return [scannable_cls] if cfg.scan_layers else [normal_cls] def get_deepseek(): - if cfg.use_batch_split_schedule: - return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer] return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer] layer_map = { @@ -501,6 +615,7 @@ def get_deepseek(): DecoderBlockType.GEMMA: [gemma.GemmaDecoderLayer], DecoderBlockType.GEMMA2: [gemma2.Gemma2DecoderLayer], DecoderBlockType.GEMMA3: [gemma3.Gemma3DecoderLayer], + DecoderBlockType.GEMMA4: get_scannable(gemma4.Gemma4DecoderLayer, gemma4.Gemma4ScannableBlock), DecoderBlockType.GPT3: [gpt3.Gpt3DecoderLayer], DecoderBlockType.QWEN3: [qwen3.Qwen3DecoderLayer], DecoderBlockType.QWEN3_MOE: [qwen3.Qwen3MoeDecoderLayer], @@ -543,12 +658,10 @@ def get_remat_policy(self): cfg = self.config if cfg.remat_policy != "none": if cfg.remat_policy in ("minimal_with_context", "minimal_flash"): - # save all if cfg.remat_policy == "minimal_flash": max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.") policy = self.minimal_policy(with_context=True) elif cfg.remat_policy == "minimal": - # save all except context policy = self.minimal_policy() elif cfg.remat_policy == "minimal_with_quantization": if cfg.scan_layers: @@ -609,7 +722,6 @@ def get_remat_policy(self): offload_dst="pinned_host", ) elif cfg.remat_policy == "minimal_offloaded": - # offload all except context policy = jax.checkpoint_policies.save_and_offload_only_these_names( names_which_can_be_saved=[], names_which_can_be_offloaded=[ @@ -651,6 +763,7 @@ def get_norm_layer(self, num_features: int, rngs: nnx.Rngs): DecoderBlockType.GEMMA, DecoderBlockType.GEMMA2, DecoderBlockType.GEMMA3, + DecoderBlockType.GEMMA4, DecoderBlockType.QWEN3, DecoderBlockType.QWEN3_MOE, DecoderBlockType.GPT_OSS, @@ -666,7 +779,7 @@ def get_norm_layer(self, num_features: int, rngs: nnx.Rngs): ) elif self.config.decoder_block == DecoderBlockType.QWEN3_NEXT: return functools.partial( - normalizations.Qwen3NextRMSNorm, num_features=num_features, shard_mode=self.config.shard_mode, rngs=rngs + normalizations.RMSNorm, num_features=num_features, shard_mode=self.config.shard_mode, rngs=rngs ) else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") @@ -678,11 +791,7 @@ def _apply_embedding( decoder_positions, deterministic, model_mode, - image_embeddings=None, - bidirectional_mask=None, - image_masks=None, - audio_embeddings=None, - audio_masks=None, + multimodal_input=None, ): """Applies token and positional embeddings to the input tokens.""" cfg = self.config @@ -690,35 +799,43 @@ def _apply_embedding( y = shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode) # Merge the image embeddings with the text embeddings for multimodal models - if image_embeddings is not None and cfg.use_multimodal: - if cfg.model_name in [ - "gemma3-4b", - "gemma3-12b", - "gemma3-27b", - "llama4-17b-16e", - 
"llama4-17b-128e", - "qwen3-omni-30b-a3b", - ]: - y = mm_utils.merge_mm_embeddings( - text_embeddings=y, - multimodal_embeddings=image_embeddings, - mask=bidirectional_mask, - token_masks=image_masks, - ) - # TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed - else: - raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}") - - if audio_embeddings is not None and cfg.use_audio: - if cfg.model_name in ["qwen3-omni-30b-a3b"]: - y = mm_utils.merge_mm_embeddings( - text_embeddings=y, - multimodal_embeddings=audio_embeddings, - mask=audio_masks, - token_masks=None, - ) - else: - raise ValueError(f"Unsupported model_name for audio: {cfg.model_name}") + if multimodal_input is not None: + image_embeddings = multimodal_input.image_embeddings + bidirectional_mask = multimodal_input.bidirectional_mask + image_masks = multimodal_input.image_masks + audio_embeddings = multimodal_input.audio_embeddings + audio_masks = multimodal_input.audio_masks + + if image_embeddings is not None and cfg.use_multimodal: + if cfg.model_name in [ + "gemma3-4b", + "gemma3-12b", + "gemma3-27b", + "gemma4-26b", + "gemma4-31b", + "llama4-17b-16e", + "llama4-17b-128e", + "qwen3-omni-30b-a3b", + ]: + y = mm_utils.merge_mm_embeddings( + text_embeddings=y, + multimodal_embeddings=image_embeddings, + mask=bidirectional_mask, + token_masks=image_masks, + ) + else: + raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}") + + if audio_embeddings is not None and cfg.use_audio: + if cfg.model_name in ["qwen3-omni-30b-a3b"]: + y = mm_utils.merge_mm_embeddings( + text_embeddings=y, + multimodal_embeddings=audio_embeddings, + mask=audio_masks, + token_masks=None, + ) + else: + raise ValueError(f"Unsupported model_name for audio: {cfg.model_name}") y = self.dropout(y, deterministic=deterministic) y = y.astype(cfg.dtype) @@ -736,7 +853,7 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode): cfg = self.config if cfg.shard_mode == ShardMode.EXPLICIT: - norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length", "activation_embed")) + norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) else: norm_out_sharding = None @@ -747,7 +864,7 @@ def apply_output_head(self, shared_embedding, y, deterministic, model_mode): out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) else: out_sharding = create_sharding( - self.mesh, ("activation_embed_and_logits_batch", "activation_length", "activation_vocab") + self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab") ) # [batch, length, emb_dim] -> [batch, length, vocab_size] @@ -781,39 +898,13 @@ def _build_linen_params(self, moe_stack: nnx.Module) -> dict: Bridges NNX to Linen by creating a dictionary that mimics the exact variable structure expected by `deepseek_batchsplit.fetch_weights`. 
""" + state_dict = nnx.state(moe_stack, nnx.Param) return { - "pre_self_attention_layer_norm": { - "scale": moe_stack.pre_self_attention_layer_norm.scale, - }, - "post_self_attention_layer_norm": { - "scale": moe_stack.post_self_attention_layer_norm.scale, - }, - "self_attention": { - "wq_a": {"kernel": moe_stack.self_attention.wq_a.kernel}, - "wq_b": {"kernel": moe_stack.self_attention.wq_b.kernel}, - "q_norm": {"scale": moe_stack.self_attention.q_norm.scale}, - "wkv_a": {"kernel": moe_stack.self_attention.wkv_a.kernel}, - "wkv_b": {"kernel": moe_stack.self_attention.wkv_b.kernel}, - "kv_norm": {"scale": moe_stack.self_attention.kv_norm.scale}, - "out": {"kernel": moe_stack.self_attention.out.kernel}, - }, - "DeepSeekMoeBlock_0": { - "MoeBlock_0": { - "gate": { - "kernel": moe_stack.DeepSeekMoeBlock_0.MoeBlock_0.gate.kernel, - "bias": moe_stack.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias, - }, - "wi_0": moe_stack.DeepSeekMoeBlock_0.MoeBlock_0.wi_0, - "wi_1": moe_stack.DeepSeekMoeBlock_0.MoeBlock_0.wi_1, - "wo": moe_stack.DeepSeekMoeBlock_0.MoeBlock_0.wo, - }, - "shared_experts": { - "wi_0": {"kernel": moe_stack.DeepSeekMoeBlock_0.shared_experts.wi_0.kernel}, - "wi_1": {"kernel": moe_stack.DeepSeekMoeBlock_0.shared_experts.wi_1.kernel}, - "wo": {"kernel": moe_stack.DeepSeekMoeBlock_0.shared_experts.wo.kernel}, - }, - }, + "pre_self_attention_layer_norm": state_dict["pre_self_attention_layer_norm"], + "post_self_attention_layer_norm": state_dict["post_self_attention_layer_norm"], + "self_attention": state_dict["self_attention"], + "DeepSeekMoeBlock_0": state_dict.get("moe_block", state_dict.get("DeepSeekMoeBlock_0")), } def _find_next_boundary(self, current_idx, end_idx, engram_indices): @@ -823,28 +914,18 @@ def _find_next_boundary(self, current_idx, end_idx, engram_indices): return min(end_idx, *next_engrams) return end_idx - def _apply_single_engram_layer(self, y, current_idx, layer_stack, *args, **kwargs): - """Applies a single, unscanned Engram layer by dynamically slicing the NNX state.""" - graphdef, state = nnx.split(layer_stack) + def _apply_single_engram_layer(self, y, layer_name, *args, **kwargs): + """Applies a single, unscanned Engram layer.""" + layer = getattr(self, layer_name) - # Slice the parameters for the current index (assuming scan axis is 0) - sliced_state = jax.tree.map(lambda x: x[current_idx], state) - single_layer = nnx.merge(graphdef, sliced_state) + decoder_input_tokens = kwargs.get("decoder_input_tokens") + layer_kwargs = kwargs.get("layer_kwargs", {}) - # Run the single layer - out = single_layer( - y, *args, decoder_input_tokens=kwargs.get("decoder_input_tokens"), **kwargs.get("layer_kwargs", {}) - ) - y = out[0] if isinstance(out, tuple) else out - - # Re-merge the updated state back into the specific slice of the stack - new_single_state = nnx.state(single_layer) - updated_state = jax.tree.map( - lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, jnp.expand_dims(new_s, axis=0), current_idx, axis=0), - state, - new_single_state, - ) - nnx.update(layer_stack, updated_state) + out = layer(y, *args, decoder_input_tokens=decoder_input_tokens, **layer_kwargs) + if isinstance(out, tuple): + y = out[0] + else: + y = out return y @@ -853,10 +934,15 @@ def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args scan_length = next_boundary - current_idx if scan_length > 0: graphdef, state = nnx.split(layer_stack) + params, rest = state.split(nnx.Param, ...) 
+ scan_axis = self.config.param_scan_axis - # Slice the chunk state - chunk_state = jax.tree.map(lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0), state) - chunk_stack = nnx.merge(graphdef, chunk_state) + # Slice the chunk state along the correct axes + chunk_params = jax.tree.map( + lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=scan_axis), params + ) + chunk_rest = jax.tree.map(lambda x: jax.lax.dynamic_slice_in_dim(x, current_idx, scan_length, axis=0), rest) + chunk_stack = nnx.merge(graphdef, chunk_params, chunk_rest) # Apply sequentially y, chunk_stack = self._apply_layers_sequentially( @@ -864,24 +950,37 @@ def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args ) # Update the original stack state - new_chunk_state = nnx.state(chunk_stack) - updated_state = jax.tree.map( - lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0), state, new_chunk_state + new_state = nnx.state(chunk_stack) + new_params, new_rest = new_state.split(nnx.Param, ...) + + updated_params = jax.tree.map( + lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=scan_axis), params, new_params ) - nnx.update(layer_stack, updated_state) + updated_rest = jax.tree.map( + lambda s, new_s: jax.lax.dynamic_update_slice_in_dim(s, new_s, current_idx, axis=0), rest, new_rest + ) + + nnx.update(layer_stack, updated_params, updated_rest) return y - def _apply_interleaved_scanned_layers(self, y, layer_stack, start_idx, end_idx, engram_indices, *args, **kwargs): + def _apply_interleaved_scanned_layers(self, y, layer_prefix, start_idx, end_idx, engram_indices, *args, **kwargs): """Applies a mix of scanned standard layers and unscanned Engram layers.""" current_idx = start_idx while current_idx < end_idx: if current_idx in engram_indices: - y = self._apply_single_engram_layer(y, current_idx, layer_stack, *args, **kwargs) + layer_name = f"{layer_prefix}_engram_{current_idx}" + y = self._apply_single_engram_layer(y, layer_name, *args, **kwargs) current_idx += 1 else: next_boundary = self._find_next_boundary(current_idx, end_idx, engram_indices) - y = self._apply_scanned_chunk(y, current_idx, next_boundary, layer_stack, *args, **kwargs) + chunk_name = f"{layer_prefix}_{current_idx}_{next_boundary - 1}" + chunk_stack = getattr(self, chunk_name) + scan_length = next_boundary - current_idx + + y, chunk_stack = self._apply_layers_sequentially( + chunk_stack, y, *args, length=scan_length, **kwargs.get("layer_kwargs", {}) + ) current_idx = next_boundary return y @@ -896,13 +995,9 @@ def __call__( previous_chunk=None, slot: None | int = None, page_state: None | page_manager.PageState = None, - bidirectional_mask: None | Any = None, - image_embeddings: None | jnp.ndarray = None, - image_masks: None | jnp.ndarray = None, + multimodal_input: None | Any = None, kv_caches: list[jax.Array] | None = None, attention_metadata=None, - audio_embeddings: None | jnp.ndarray = None, - audio_masks: None | jnp.ndarray = None, deepstack_visual_embeds: None | list[jnp.ndarray] = None, ): cfg = self.config @@ -917,11 +1012,7 @@ def __call__( decoder_positions, deterministic, model_mode, - image_embeddings, - bidirectional_mask, - image_masks, - audio_embeddings, - audio_masks, + multimodal_input=multimodal_input, ) mhc_expand, mhc_reduce = mhc.get_functions(cfg.mhc_expansion_rate) @@ -932,7 +1023,10 @@ def __call__( layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) layer_kwargs = {} - if 
cfg.decoder_block == DecoderBlockType.GEMMA3: + # Extract the bidirectional mask locally for layer configurations + bidirectional_mask = multimodal_input.bidirectional_mask if multimodal_input is not None else None + + if cfg.decoder_block in (DecoderBlockType.GEMMA3, DecoderBlockType.GEMMA4): layer_kwargs["bidirectional_mask"] = bidirectional_mask if attention_metadata is not None: @@ -953,15 +1047,15 @@ def __call__( } y = self._apply_interleaved_scanned_layers( - y, self.dense_layers, 0, cfg.first_num_dense_layers, cfg.engram_layers, *layer_args, **common_kwargs + y, "dense_layers", 0, cfg.first_num_dense_layers, cfg.engram_layers, *layer_args, **common_kwargs ) y = self._apply_interleaved_scanned_layers( y, - self.moe_layer, - 0, - (cfg.num_decoder_layers - cfg.first_num_dense_layers), - [e - cfg.first_num_dense_layers for e in cfg.engram_layers], + "moe_layers", + cfg.first_num_dense_layers, + cfg.num_decoder_layers, + cfg.engram_layers, *layer_args, **common_kwargs, ) @@ -973,19 +1067,34 @@ def __call__( num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers if cfg.use_batch_split_schedule: - mock_params = self._build_linen_params(self.moe_layer) - - y = deepseek_batchsplit.scan_batch_split_layers( - y, - mock_params, - decoder_positions, - mesh=self.mesh, - cfg=cfg, - num_layers=num_moe, - ) + policy = self.get_remat_policy() + mock_params = self._build_linen_params(self.moe_layers) + + if cfg.use_qwix_quantization: + y = deepseek_batchsplit_fp8.scan_batch_split_layers( + y, + mock_params, + decoder_positions, + decoder_segment_ids, + model_mode=model_mode, + mesh=self.mesh, + quant=self.quant, + cfg=cfg, + policy=policy, + ) + else: + # bf16 code path + y = deepseek_batchsplit.scan_batch_split_layers( + y, + mock_params, + decoder_positions, + mesh=self.mesh, + cfg=cfg, + num_layers=num_moe, + ) else: - y, self.moe_layer = self._apply_layers_sequentially( - self.moe_layer, y, *layer_args, length=num_moe, **layer_kwargs + y, self.moe_layers = self._apply_layers_sequentially( + self.moe_layers, y, *layer_args, length=num_moe, **layer_kwargs ) elif self.is_gemma3: y = self._apply_gemma3_scanned_blocks( @@ -999,9 +1108,24 @@ def __call__( page_state, slot, ) + elif self.is_gemma4: + y = self._apply_gemma4_scanned_blocks( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + bidirectional_mask, + previous_chunk, + page_state, + slot, + ) else: scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) - y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs) + if scan_length > 0: + y, self.layers = self._apply_layers_sequentially( + self.layers, y, *layer_args, length=scan_length, **layer_kwargs + ) else: prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) @@ -1019,7 +1143,16 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): for lyr, layer in enumerate(self.layers): graphdef, state = nnx.split(layer) - kv_cache = kv_caches[lyr] if kv_caches is not None else None + if kv_caches is not None: + if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT: + if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0: + kv_cache = (kv_caches["key_cache"][lyr], kv_caches["value_cache"][lyr]) + else: + kv_cache = None + else: + kv_cache = kv_caches[lyr] + else: + kv_cache = None input_tokens = decoder_input_tokens if cfg.engram_layers else None if input_tokens is not None: @@ -1029,7 +1162,12 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): nnx.update(layer, 
new_state) if kv_caches is not None and kv_cache is not None: - kv_caches[lyr] = kv_cache + if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT: + if (lyr + 1) % cfg.inhomogeneous_layer_cycle_interval == 0: + kv_caches["key_cache"][lyr] = kv_cache[0] + kv_caches["value_cache"][lyr] = kv_cache[1] + else: + kv_caches[lyr] = kv_cache if deepstack_visual_embeds is not None and lyr < len(deepstack_visual_embeds): visual_embeds = deepstack_visual_embeds[lyr] @@ -1049,9 +1187,14 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in): if cfg.attention == "vllm_rpa": logits = None + # When in the Indexer Dense Warm-up stage, skip the expensive output head projection + # for efficiency, as the main model is frozen and the LM loss is not needed. + elif (cfg.use_indexer and not cfg.indexer_sparse_training) and self.model_mode == MODEL_MODE_TRAIN: + logits = None + # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory # Instead, we keep track on the hidden states, which has smaller size compared to full logits - if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: + elif cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: logits = None self.sow(nnx.Intermediate, "hidden_states", hidden_state) @@ -1108,6 +1251,54 @@ return y + def _apply_gemma4_scanned_blocks( + self, + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + bidirectional_mask, + previous_chunk, + page_state, + slot, + ): + """Applies Gemma4 scanned decoder blocks, handling main scan and remainders.""" + + cfg = self.config + + # Define the repeating pattern length and calculate how many full blocks to scan + attention_pattern_length = len(gemma4.GEMMA4_ATTENTION_PATTERN) + scan_length = cfg.num_decoder_layers // attention_pattern_length + + layer_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) + layer_kwargs = {"bidirectional_mask": bidirectional_mask} + + # Apply the main scan over the full blocks + if scan_length > 0: + y, self.layers = self._apply_layers_sequentially(self.layers, y, *layer_args, length=scan_length, **layer_kwargs) + + # Apply any remaining layers that did not fit into a full scanned block + num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length + if num_remaining_layers > 0: + policy = self.get_remat_policy() + prevent_cse = maxtext_utils.should_prevent_cse_in_remat(cfg) + + def pure_gemma_fn(graphdef, state_in, y_in): + merged_layer = nnx.merge(graphdef, state_in) + out_y, _ = merged_layer( + y_in, *layer_args, previous_chunk=previous_chunk, page_state=page_state, slot=slot, **layer_kwargs + ) + return out_y, nnx.state(merged_layer) + + checkpointed_gemma_fn = jax.checkpoint(pure_gemma_fn, policy=policy, prevent_cse=prevent_cse) + + graphdef, state = nnx.split(self.layers_remainder) + y, new_state = checkpointed_gemma_fn(graphdef, state, y) + nnx.update(self.layers_remainder, new_state) + + return y + def decoder_as_linen( config: Config, @@ -1116,7 +1307,7 @@ model_mode: str, quant: None | Quant = None, ): """Creates a Decoder module.""" module = nnx_wrappers.to_linen( NNXDecoder, config=config, diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index f5dd4e6cc3..1b0d4b4cd3 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -33,7 +33,7 @@ from maxtext.layers.decoders import Decoder from maxtext.layers.embeddings import Embed, 
embed_as_linen from maxtext.layers.encoders import AudioEncoder, VisionEncoder, audio_encoder_as_linen, vision_encoder_as_linen -from maxtext.layers.multi_token_prediction import multi_token_prediction_block_as_linen +from maxtext.layers.multi_token_prediction import MultiTokenPredictionBlock, multi_token_prediction_block_as_linen from maxtext.layers.quantizations import AqtQuantization as Quant from maxtext.multimodal import processor as mm_processor from maxtext.utils import max_utils @@ -386,25 +386,12 @@ def __init__( # For MTP, we use the DecoderLayer blueprint to ensure architectural consistency. # By convention, this is the last layer in the list. mtp_layer = layer_types[-1] - mtp_block_linen = multi_token_prediction_block_as_linen( + self.mtp_block = MultiTokenPredictionBlock( config=self.config, mesh=self.mesh, transformer_layer_module=mtp_layer, decoder=self.decoder, rngs=rngs, - name="mtp_block", - ) - self.mtp_block = nnx_wrappers.ToNNX(mtp_block_linen, rngs=rngs) - - self.mtp_block.lazy_init( - shared_embedding=self.token_embedder, - main_hidden_state=jnp.ones((1, 1, self.config.emb_dim), dtype=self.config.dtype), - input_ids=jnp.ones((1, 1), dtype=jnp.int32), - target_ids=jnp.ones((1, 1), dtype=jnp.int32), - target_mask=jnp.ones((1, 1), dtype=jnp.int32), - position_ids=jnp.ones((1, 1), dtype=jnp.int32), - decoder_segment_ids=jnp.ones((1, 1), dtype=jnp.int32), - deterministic=True, ) def no_op(self, *args, **kwargs): diff --git a/tests/unit/nnx_decoders_test.py b/tests/unit/nnx_decoders_test.py index 8979440732..acff8afe23 100644 --- a/tests/unit/nnx_decoders_test.py +++ b/tests/unit/nnx_decoders_test.py @@ -31,7 +31,7 @@ from flax import nnx from jax.sharding import Mesh -from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_TRAIN, DecoderBlockType +from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_PREFILL, MODEL_MODE_TRAIN, DecoderBlockType from maxtext.configs import pyconfig from maxtext.layers import linears from maxtext.layers.attentions import Attention @@ -65,13 +65,8 @@ def _make_config(**overrides): """Return a pyconfig Config object suitable for unit tests.""" extra_args = get_decoupled_parallelism_overrides() - return pyconfig.initialize( - [sys.argv[0], get_test_config_path()], - **_BASE_CONFIG, - **extra_args, - **overrides, - override_model_config=True, - ) + merged = {**_BASE_CONFIG, **extra_args, **overrides} + return pyconfig.initialize([sys.argv[0], get_test_config_path()], override_model_config=True, **merged) def _make_mesh(cfg): @@ -87,6 +82,7 @@ def _make_mesh(cfg): class TestDeepstackProcess(unittest.TestCase): """Tests for the deepstack_process pure function.""" + # pylint: disable=too-many-positional-arguments def _make_inputs(self, batch=2, seq_len=8, hidden_dim=16, num_visual=3, seed=0): key = jax.random.PRNGKey(seed) k1, k2 = jax.random.split(key) @@ -188,9 +184,9 @@ def setUp(self): self.mesh = _make_mesh(self.cfg) self.rng = jax.random.PRNGKey(0) - def _make_layer(self, model_mode=MODEL_MODE_TRAIN): + def _make_layer(self, model_mode=MODEL_MODE_TRAIN, config=None): return NNXDecoderLayer( - config=self.cfg, + config=config if config is not None else self.cfg, mesh=self.mesh, model_mode=model_mode, rngs=nnx.Rngs(params=0, dropout=1), @@ -228,16 +224,60 @@ def test_forward_output_shape_train(self): """Forward pass output shape matches input shape in train mode.""" layer = self._make_layer(MODEL_MODE_TRAIN) inputs, segment_ids, positions = self._make_inputs() - out, _ = 
layer(inputs, segment_ids, positions, deterministic=True, model_mode=MODEL_MODE_TRAIN) + out, _ = layer( + inputs, + segment_ids, + positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) self.assertEqual(out.shape, inputs.shape) def test_forward_output_dtype(self): """Output dtype matches config dtype.""" layer = self._make_layer() inputs, segment_ids, positions = self._make_inputs() - out, _ = layer(inputs, segment_ids, positions, deterministic=True, model_mode=MODEL_MODE_TRAIN) + out, _ = layer( + inputs, + segment_ids, + positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) self.assertEqual(out.dtype, self.cfg.dtype) + def test_forward_prefill_mode(self): + """Test forward pass in prefill mode.""" + layer = self._make_layer(MODEL_MODE_PREFILL) + inputs, segment_ids, positions = self._make_inputs() + out, _ = layer( + inputs, + segment_ids, + positions, + deterministic=True, + model_mode=MODEL_MODE_PREFILL, + ) + self.assertEqual(out.shape, inputs.shape) + + def test_record_metrics(self): + """Test recording intermediate activation metrics.""" + cfg = _make_config(record_internal_nn_metrics=1) + layer = self._make_layer(MODEL_MODE_TRAIN, config=cfg) + inputs, segment_ids, positions = self._make_inputs() + + # Use nnx.capture to retrieve sown variables + _, state = nnx.capture(layer, nnx.Intermediate)( + inputs, + segment_ids, + positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + metrics_keys = state.keys() + self.assertIn("activation_mean", metrics_keys) + self.assertIn("activation_stdev", metrics_keys) + self.assertIn("activation_fraction_zero", metrics_keys) + def test_forward_kv_cache_is_none_when_scan_layers_false(self): """kv_cache return value is not None when scan_layers=False (non-scan returns cache).""" # With scan_layers=False the layer returns (output, kv_cache). @@ -245,7 +285,13 @@ def test_forward_kv_cache_is_none_when_scan_layers_false(self): # verify the call doesn't raise and returns a 2-tuple. 
layer = self._make_layer() inputs, segment_ids, positions = self._make_inputs() - result = layer(inputs, segment_ids, positions, deterministic=True, model_mode=MODEL_MODE_TRAIN) + result = layer( + inputs, + segment_ids, + positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) self.assertIsInstance(result, tuple) self.assertEqual(len(result), 2) @@ -253,8 +299,20 @@ def test_forward_deterministic_and_stochastic_consistent_shape(self): """Output shape is the same regardless of the deterministic flag.""" layer = self._make_layer() inputs, segment_ids, positions = self._make_inputs() - out_det, _ = layer(inputs, segment_ids, positions, deterministic=True, model_mode=MODEL_MODE_TRAIN) - out_stoch, _ = layer(inputs, segment_ids, positions, deterministic=False, model_mode=MODEL_MODE_TRAIN) + out_det, _ = layer( + inputs, + segment_ids, + positions, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + out_stoch, _ = layer( + inputs, + segment_ids, + positions, + deterministic=False, + model_mode=MODEL_MODE_TRAIN, + ) self.assertEqual(out_det.shape, out_stoch.shape) @@ -476,7 +534,11 @@ def test_logits_shape(self): deterministic=True, model_mode=MODEL_MODE_TRAIN, ) - expected = (cfg.global_batch_size_to_train_on, cfg.max_target_length, cfg.vocab_size) + expected = ( + cfg.global_batch_size_to_train_on, + cfg.max_target_length, + cfg.vocab_size, + ) self.assertEqual(logits.shape, expected) def test_hidden_state_shape(self): @@ -491,7 +553,11 @@ def test_hidden_state_shape(self): deterministic=True, model_mode=MODEL_MODE_TRAIN, ) - expected = (cfg.global_batch_size_to_train_on, cfg.max_target_length, cfg.emb_dim) + expected = ( + cfg.global_batch_size_to_train_on, + cfg.max_target_length, + cfg.emb_dim, + ) self.assertEqual(hidden_state.shape, expected) def test_logits_are_finite(self): @@ -532,6 +598,101 @@ def test_different_random_seeds_produce_different_logits(self): logits2, _, _ = decoder2(shared_emb2, ids, positions, **common_kwargs) self.assertFalse(jnp.allclose(logits1, logits2)) + def test_scan_layers(self): + """Test NNXDecoder with scan_layers=True.""" + cfg = _make_config(scan_layers=True) + rngs = nnx.Rngs(params=0, dropout=1) + decoder = NNXDecoder( + config=cfg, + mesh=self.mesh, + model_mode=MODEL_MODE_TRAIN, + rngs=rngs, + ) + shared_embedding = Embed( + num_embeddings=cfg.vocab_size, + num_features=cfg.emb_dim, + dtype=cfg.dtype, + embedding_init=nn.initializers.normal(stddev=1.0), + config=cfg, + mesh=self.mesh, + rngs=rngs, + ) + + batch = cfg.global_batch_size_to_train_on + seq_len = cfg.max_target_length + ids = jax.random.randint(self.rng, (batch, seq_len), 0, cfg.vocab_size) + segment_ids = jnp.full((batch, seq_len), DECODING_ACTIVE_SEQUENCE_INDICATOR) + positions = jnp.broadcast_to(jnp.arange(seq_len)[None], (batch, seq_len)) + + logits, _, _ = decoder( + shared_embedding, + ids, + positions, + decoder_segment_ids=segment_ids, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + self.assertEqual(logits.shape, (batch, seq_len, cfg.vocab_size)) + -if __name__ == "__main__": - unittest.main() + + +class TestNNXDecoderDeepseekAndGemma4(unittest.TestCase): + """Tests for Deepseek and Gemma4 specific decoder logic.""" + + def setUp(self): + super().setUp() + self.cfg = _make_config() + self.mesh = _make_mesh(self.cfg) + self.rng = jax.random.PRNGKey(0) + self.rngs = nnx.Rngs(params=0, dropout=1) + + def _make_token_inputs(self, cfg): + batch = cfg.global_batch_size_to_train_on + seq_len = cfg.max_target_length + ids = jax.random.randint(self.rng, (batch, 
seq_len), 0, cfg.vocab_size) + segment_ids = jnp.full((batch, seq_len), DECODING_ACTIVE_SEQUENCE_INDICATOR) + positions = jnp.broadcast_to(jnp.arange(seq_len)[None], (batch, seq_len)) + return ids, segment_ids, positions + + def _make_shared_embedding(self, cfg): + return Embed( + num_embeddings=cfg.vocab_size, + num_features=cfg.emb_dim, + dtype=cfg.dtype, + embedding_init=nn.initializers.normal(stddev=1.0), + config=cfg, + mesh=self.mesh, + rngs=self.rngs, + ) + + def test_gemma4_scanned_layers(self): + """Test NNXDecoder with gemma4 block and scan_layers=True.""" + cfg = _make_config( + decoder_block="gemma4", + scan_layers=True, + num_decoder_layers=3, # Not a multiple of the pattern length (which is usually larger) to test remainder logic + ) + decoder = NNXDecoder( + config=cfg, + mesh=self.mesh, + model_mode=MODEL_MODE_TRAIN, + rngs=self.rngs, + ) + shared_embedding = self._make_shared_embedding(cfg) + ids, segment_ids, positions = self._make_token_inputs(cfg) + + logits, _, _ = decoder( + shared_embedding, + ids, + positions, + decoder_segment_ids=segment_ids, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + ) + self.assertEqual( + logits.shape, + (cfg.global_batch_size_to_train_on, cfg.max_target_length, cfg.vocab_size), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/tiling_test.py b/tests/unit/tiling_test.py index 58b688634d..6ed33c3c67 100644 --- a/tests/unit/tiling_test.py +++ b/tests/unit/tiling_test.py @@ -209,6 +209,8 @@ def test_vocab_tiling_gradient_with_z_loss(self): num_vocab_tiling=1, z_loss_multiplier=1e-4, # Enable z-loss ) + if getattr(cfg_non_tiling, "enable_nnx", False): + pytest.skip("We currently don't support vocab tiling on NNX module.") quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) mesh_non_tiling = Mesh(devices_array_non_tiling, cfg_non_tiling.mesh_axes) @@ -275,6 +277,8 @@ def test_vocab_tiling_gradient_non_tied_embedding(self): matmul_precision="high", num_vocab_tiling=1, ) + if getattr(cfg_non_tiling, "enable_nnx", False): + pytest.skip("We currently don't support vocab tiling on NNX module.") quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) mesh_non_tiling = Mesh(devices_array_non_tiling, cfg_non_tiling.mesh_axes) @@ -340,6 +344,8 @@ def test_vocab_tiling_gradient_tied_embedding(self): num_vocab_tiling=1, ) + if getattr(cfg_non_tiling, "enable_nnx", False): + pytest.skip("We currently don't support vocab tiling on NNX module.") quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) mesh_non_tiling = Mesh(devices_array_non_tiling, cfg_non_tiling.mesh_axes) @@ -401,6 +407,8 @@ def test_vocab_tiling_gradient_data_parallelism(self): matmul_precision="high", num_vocab_tiling=1, ) + if getattr(cfg_non_tiling, "enable_nnx", False): + pytest.skip("We currently don't support vocab tiling on NNX module.") quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) mesh_non_tiling = Mesh(devices_array_non_tiling, cfg_non_tiling.mesh_axes) @@ -465,6 +473,8 @@ def test_vocab_tiling_gradient_tensor_parallelism(self): matmul_precision="high", num_vocab_tiling=1, ) + if getattr(cfg_non_tiling, "enable_nnx", False): + pytest.skip("We currently don't support vocab tiling on NNX module.") quant_non_tiling = 
quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) mesh_non_tiling = Mesh(devices_array_non_tiling, cfg_non_tiling.mesh_axes) @@ -531,6 +541,8 @@ def test_vocab_tiling_gradient_context_parallelism(self): matmul_precision="high", num_vocab_tiling=1, ) + if getattr(cfg_non_tiling, "enable_nnx", False): + pytest.skip("We currently don't support vocab tiling on NNX module.") quant_non_tiling = quantizations.configure_quantization(cfg_non_tiling) devices_array_non_tiling = maxtext_utils.create_device_mesh(cfg_non_tiling) mesh_non_tiling = Mesh(devices_array_non_tiling, cfg_non_tiling.mesh_axes) diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index 273708defa..11b5623cb2 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -26,7 +26,9 @@ import pytest import transformers + from maxtext.checkpoint_conversion.utils.hf_model_configs import DeepseekV32Config +from maxtext.configs import pyconfig from maxtext.trainers.pre_train.train_compile import main as train_compile_main from tests.utils.test_helpers import get_test_config_path @@ -504,6 +506,10 @@ def test_moe_dense_int8(self): @pytest.mark.cpu_only def test_moe_pp_bf16(self): + cfg = pyconfig.initialize([None, get_test_config_path()]) + if getattr(cfg, "pure_nnx_decoder", False): + pytest.skip("Pipeline parallelism not supported for pure_nnx_decoder=True") + temp_dir = gettempdir() compiled_trainstep_file = os.path.join(temp_dir, "test_moe_pp_bf16.pickle") train_compile_main( @@ -601,6 +607,10 @@ def test_moe_deepseek_with_device_limit(self): @pytest.mark.cpu_only def test_moe_deepseek_pipeline_subset(self): + cfg = pyconfig.initialize([None, get_test_config_path()]) + if getattr(cfg, "pure_nnx_decoder", False): + pytest.skip("Pipeline parallelism not supported for pure_nnx_decoder=True") + compiled_trainstep_file = "/tmp/test_moe_deepseek_pipeline_subset.pickle" train_compile_main( ( @@ -624,6 +634,10 @@ def test_moe_deepseek_pipeline_subset(self): @pytest.mark.cpu_only def test_pipeline_subset(self): + cfg = pyconfig.initialize([None, get_test_config_path()]) + if getattr(cfg, "pure_nnx_decoder", False): + pytest.skip("Test not supported for pure_nnx_decoder=True") + compiled_trainstep_file = "/tmp/test_pipeline_subset.pickle" train_compile_main( ( @@ -904,6 +918,10 @@ def test_engram_integration(self): @pytest.mark.cpu_only def test_circular_pipeline_ag_per_repeat_ep_ds(self): + cfg = pyconfig.initialize([None, get_test_config_path()]) + if getattr(cfg, "pure_nnx_decoder", False): + pytest.skip("Pipeline parallelism not supported for pure_nnx_decoder=True") + temp_dir = gettempdir() compiled_trainstep_file = os.path.join(temp_dir, "test_circular_pipeline_ag_per_repeat_ep_ds.pickle") train_compile_main( @@ -959,6 +977,10 @@ def test_qk_clip(self): @pytest.mark.cpu_only def test_vocab_tiling_bf16(self): """test vocab_tiling when weight_dtype=bfloat16""" + cfg = pyconfig.initialize([None, get_test_config_path()]) + if getattr(cfg, "enable_nnx", False): + pytest.skip("Vocab tiling not supported on NNX.") + compiled_trainstep_file = "/tmp/test_vocab_tiling_bf16.pickle" train_compile_main( (