diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 75535fae29..e67329ecbd 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -20,6 +20,7 @@ from absl import flags import datetime from etils import epath +from flax import nnx from flax.training import train_state import jax from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE @@ -536,7 +537,7 @@ def load_state_if_possible( load_parameters_from_path: str, load_full_state_from_path: str, checkpoint_storage_concurrent_gb: int, - abstract_unboxed_pre_state: train_state.TrainState, + abstract_unboxed_pre_state: train_state.TrainState | nnx.State, enable_single_replica_ckpt_restoring: bool | None = False, dataset_type: str | None = "tfds", step: int = -1, # -1 means latest @@ -604,9 +605,14 @@ def map_to_pspec(data): ) ocp.type_handlers.register_type_handler(jax.Array, array_handler, override=True) - restore_args = jax.tree_util.tree_map(map_to_pspec, abstract_unboxed_pre_state) + # Convert nnx.State to pure dict to match how checkpoints are saved for NNX + restore_target = abstract_unboxed_pre_state + if isinstance(abstract_unboxed_pre_state, nnx.State): + restore_target = abstract_unboxed_pre_state.to_pure_dict() + + restore_args = jax.tree_util.tree_map(map_to_pspec, restore_target) checkpoint_args = ocp.args.PyTreeRestore( - item=abstract_unboxed_pre_state, + item=restore_target, restore_args=restore_args, partial_restore=True, ) @@ -620,9 +626,7 @@ def map_to_pspec(data): (EmergencyCheckpointManager, EmergencyReplicatorCheckpointManager), ): return ( - checkpoint_manager.restore( - step, args=Composite(state=checkpoint_args) - ).state, + checkpoint_manager.restore(step, args=Composite(state=checkpoint_args)).state, None, ) # Case 2: Matches if dataset type is "grain" and the data iterator is not a @@ -647,9 +651,14 @@ def map_to_pspec(data): return (checkpoint_manager.restore(step, 
args=Composite(items=checkpoint_args)), None) if load_parameters_from_path != "": + if isinstance(abstract_unboxed_pre_state, nnx.State): + _, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...) + else: + params = abstract_unboxed_pre_state.params + restored_params = load_params_from_path( load_parameters_from_path, - abstract_unboxed_pre_state.params, + params, checkpoint_storage_concurrent_gb, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3, @@ -741,7 +750,18 @@ def maybe_save_checkpoint(checkpoint_manager, state, config, data_iterator, step # Determine the effective step for saving a checkpoint. # If 'step' is not provided, this call is for a potential final checkpoint # and use the last completed step from the state. - actual_step = (int(state.step) - 1) if step is None else int(step) + if step is not None: + actual_step = int(step) + else: + if config.pure_nnx: + actual_step = int(state.optimizer.step) - 1 + else: + # Linen TrainState has .step attribute + actual_step = int(state.step) - 1 + + if config.pure_nnx: + # Convert nnx.State to dict. + state = state.to_pure_dict() # Determine if a checkpoint save should be forced, overriding the usual `config.checkpoint_period` logic. 
# This occurs if this function was called: diff --git a/src/maxtext/layers/nnx_decoders.py b/src/maxtext/layers/nnx_decoders.py index 3c8a601201..59c3b6cc06 100644 --- a/src/maxtext/layers/nnx_decoders.py +++ b/src/maxtext/layers/nnx_decoders.py @@ -35,6 +35,7 @@ MODEL_MODE_TRAIN, Config, DecoderBlockType, + MultimodalInput, ShardMode, ) from maxtext.inference import page_manager @@ -904,7 +905,14 @@ def __call__( audio_embeddings: None | jnp.ndarray = None, audio_masks: None | jnp.ndarray = None, deepstack_visual_embeds: None | list[jnp.ndarray] = None, + multimodal_input: None | MultimodalInput = None, ): + if multimodal_input is not None: + image_embeddings = multimodal_input.image_embeddings + image_masks = multimodal_input.image_masks + audio_embeddings = multimodal_input.audio_embeddings + audio_masks = multimodal_input.audio_masks + bidirectional_mask = multimodal_input.bidirectional_mask cfg = self.config assert decoder_input_tokens.ndim == 2 # [batch, len] diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 912157f323..c06cdb87ca 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -35,8 +35,9 @@ import jax import jax.numpy as jnp +from jax.sharding import NamedSharding -from flax import linen as nn +from flax import linen as nn, nnx from flax.linen import partitioning as nn_partitioning from maxtext.configs import pyconfig @@ -67,6 +68,7 @@ from maxtext.utils import maxtext_utils from maxtext.utils import qk_clip_utils from maxtext.utils import sharding +from maxtext.utils import maxtext_utils_nnx from maxtext.utils import train_utils from maxtext.utils.gradient_accumulation import gradient_accumulation_loss_and_grad from maxtext.utils.vocabulary_tiling import vocab_tiling_linen_loss @@ -87,17 +89,15 @@ def get_first_step(model, state): # ----------------------------------------------------------------------------- -def loss_fn( - model, config, data, 
dropout_rng, params, sparsity_state=None, is_train=True -): +def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_train=True): """loss_fn for both train and eval. Args: - model: A nn.Module + model: A nn.Module (Linen) or nnx.Module (NNX). config: Config of parameters data: Batch of data to apply to the model - dropout_rng: A key to use to generate rng for dropout - params: Model params + dropout_rng: A key to use to generate rng for dropout (Linen); unused for NNX. + params: Model params (Linen); unused for NNX (params are part of the model). is_train: True for train_step and False for eval_step Returns: @@ -121,9 +121,7 @@ def loss_fn( # make its specific collection mutable so the MTPBlock can sow into it. if config.mtp_eval_target_module > 0 and not is_train: mutable_collections.append("mtp_acceptance") - sparsity_enabled = ( - is_train and config.weight_sparsity_n and config.weight_sparsity_m - ) + sparsity_enabled = is_train and config.weight_sparsity_n and config.weight_sparsity_m if sparsity_enabled: mutable_collections.append("batch_stats") if isinstance(model, nn.Module): @@ -143,9 +141,7 @@ def loss_fn( data["inputs_position"], decoder_segment_ids=data["inputs_segmentation"], encoder_images=data["images"] if config.use_multimodal else None, - encoder_image_masks=data["image_masks"] - if config.use_multimodal and "image_masks" in data - else None, + encoder_image_masks=data["image_masks"] if config.use_multimodal and "image_masks" in data else None, enable_dropout=config.enable_dropout if is_train else False, rngs={"dropout": rng1, "params": aqt_rng}, mutable=mutable_collections, @@ -188,7 +184,7 @@ def loss_fn( xent_sum = jnp.sum(xent) total_z_loss = jnp.sum(z_loss) else: - # Flax NNX model + # Flax NNX model: logits = model( decoder_input_tokens=data["inputs"], decoder_positions=data["inputs_position"], @@ -199,7 +195,11 @@ def loss_fn( decoder_target_tokens=data["targets"], decoder_target_mask=data["targets_segmentation"], ) - 
intermediate_outputs = {} + # Capture NNX intermediates (MoE losses, hidden states, etc.) + intermediate_outputs = nnx.state(model, nnx.Intermediate).to_pure_dict() + + if config.num_vocab_tiling > 1: + raise NotImplementedError("Vocab tiling for NNX modules has not been implemented.") if (config.use_indexer and not config.indexer_sparse_training) and is_train: # In Dense Warm-up stage, we skip main model loss calculation for efficiency. @@ -286,83 +286,116 @@ def loss_fn( "indexer_loss": indexer_loss, "moe_bias_updates": moe_bias_updates, "mtp_loss": mtp_loss, - "batch_stats": ( - intermediate_outputs.get("batch_stats", None) - if hasattr(intermediate_outputs, "get") - else None - ), + "batch_stats": (intermediate_outputs.get("batch_stats", None) if hasattr(intermediate_outputs, "get") else None), } return loss, aux -def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng): - """ +def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng=None): + """Training step for both Linen and NNX models. Args: - model: A nn.Module - state: A pytree of the current state of the model - data: Batch of data to apply to the model - dropout_rng: A key to use to generate rng for dropout + model: A nn.Module (Linen) or nnx.GraphDef of the TrainStateNNX (NNX). + config: Hyperparameters. + state_mesh_shardings: PyTree of PartitionSpecs for the train state. + params_shardings: PyTree of PartitionSpecs for model parameters, used for gradient accumulation. + state: Linen TrainState or NNX pure State. + data: Training data batch. + dropout_rng: A key to use to generate rng for dropout (Linen); unused for NNX. Returns: - new_state: Same format as state. + new_state: Updated Linen TrainState or NNX pure State. metrics: Dictionary of model metrics such as loss, training rate, etc. - rng2: A new rng key that can be used in future calls. 
- """ - reference_params, reference_params_sharding, extra_dpo_args, _loss_fn = ( - [], - [], - [], - loss_fn, - ) - if config.use_dpo: - state, reference_params = _split_dpo_state(state) - state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings) - extra_dpo_args = [reference_params] - _loss_fn = dpo_loss_fn + # --- Per-path initialization --- + if isinstance(model, nn.Module): + reference_params, reference_params_sharding, extra_dpo_args, _loss_fn = [], [], [], loss_fn + if config.use_dpo: + state, reference_params = _split_dpo_state(state) + state_mesh_shardings, reference_params_sharding = _split_dpo_state(state_mesh_shardings) + extra_dpo_args = [reference_params] + _loss_fn = dpo_loss_fn + params = state.params + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = _loss_fn, model, params, dropout_rng, extra_dpo_args + else: + if config.use_dpo: + raise NotImplementedError("DPO for NNX modules has not been implemented.") + state = nnx.merge(model, state) # reconstruct TrainStateNNX + ga_fn, ga_model, ga_params, ga_rng, ga_dpo = loss_fn, state.model, None, None, [] - params = state.params + # --- Gradient computation --- if config.gradient_accumulation_steps > 1: loss, aux, raw_grads = gradient_accumulation_loss_and_grad( - _loss_fn, + ga_fn, config, - model, - params, + ga_model, + ga_params, params_shardings, data, - dropout_rng, - extra_dpo_args, + ga_rng, + ga_dpo, ) else: - if config.optimizer_memory_host_offload: - if config.use_dpo: + if isinstance(model, nn.Module): + if config.optimizer_memory_host_offload and config.use_dpo: reference_params = jax.device_put( reference_params, max_utils.with_memory_kind(reference_params_sharding, "device"), ) extra_dpo_args = [reference_params] - if config.shard_optimizer_over_data: - params = jax.tree.map( - functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), - params, - params_shardings, + if config.shard_optimizer_over_data: + params = jax.tree.map( + 
functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), + params, + params_shardings, + ) + sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m + pure_params = params["params"] if sparsity_enabled else params + batch_stats = params.get("batch_stats", {}) + + grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) + (loss, aux), raw_grads = grad_func( + model, + config, + data, + dropout_rng, + pure_params, + *extra_dpo_args, + sparsity_state=batch_stats, + is_train=True, ) - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - pure_params = params["params"] if sparsity_enabled else params - batch_stats = params.get("batch_stats", {}) + else: + model_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...) + if config.parameter_memory_host_offload: + # Params are kept on host (pinned_host) in in_shardings. Move only Param + # variables to device before the forward/backward pass so that all dot_general + # operands share the same memory space (XLA on GPU requires this). + # Using params_shardings (Param-only) avoids Shardy rank mismatches that + # occur when applying PartitionSpec() (rank-0 in SDY) to rank-1 RNG key tensors. 
+ device_param_shardings = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_device, + params_shardings, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + curr_params = jax.device_put(curr_params, device_param_shardings) + nnx.update(state.model, curr_params) # ensure state.model has device params for optimizer update + if config.shard_optimizer_over_data: + curr_params = jax.tree.map( + functools.partial(sharding.maybe_shard_with_name, shard_mode=config.shard_mode), + curr_params, + params_shardings, + ) + nnx.update(state.model, curr_params) - grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) - (loss, aux), raw_grads = grad_func( - model, - config, - data, - dropout_rng, - pure_params, - *extra_dpo_args, - sparsity_state=batch_stats, - is_train=True, - ) + def diff_wrapper(param, rest, config, data): + local_model = nnx.merge(model_graphdef, param, rest, copy=True) + loss, aux = loss_fn(local_model, config, data, None, None, is_train=True) + _, _, new_rest = nnx.split(local_model, nnx.Param, ...) 
+ return loss, (aux, new_rest) + + grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True) + (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data) + nnx.update(state.model, new_rest) raw_grads = jax.tree_util.tree_map( lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, @@ -373,6 +406,8 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat raw_grads, max_utils.with_memory_kind(params_shardings, "device"), ) + + # Extract aux fields into locals intermediate_outputs = aux["intermediate_outputs"] xent_sum = aux["xent_sum"] total_weights = aux["total_weights"] @@ -382,69 +417,90 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat moe_bias_updates = aux.get("moe_bias_updates") mtp_loss = aux.get("mtp_loss", 0.0) - if config.gradient_clipping_threshold > 0: - grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold) - else: - grads = raw_grads - - if config.optimizer_memory_host_offload: - state = state.replace( - opt_state=jax.device_put( - state.opt_state, - jax.tree_util.tree_map( - lambda x: x.with_memory_kind(kind="device"), - state_mesh_shardings.opt_state, - ), - ) - ) - # Move all parameters to device before optimizer update - if config.parameter_memory_host_offload: - max_logging.log("\nMoving all parameters to device before optimizer update") - - def move(path, value): - max_logging.log(f"train.py: Moving f{path} to device") - return value.with_memory_kind(kind="device") - - state = state.replace( - params=jax.device_put( - state.params, - jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params), - ) - ) - # Re-wrap grads to match state.params structure if it's a dict of collections - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - if sparsity_enabled: - full_grads = {"params": grads} - if sparsity_enabled and "batch_stats" in state.params: - batch_stats_grads = 
jax.tree_util.tree_map( - jnp.zeros_like, state.params.get("batch_stats", {}) + if isinstance(model, nn.Module): + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold) + else: + grads = raw_grads + if config.optimizer_memory_host_offload: + state = state.replace( + opt_state=jax.device_put( + state.opt_state, + jax.tree_util.tree_map( + lambda x: x.with_memory_kind(kind="device"), + state_mesh_shardings.opt_state, + ), + ) ) - full_grads["batch_stats"] = batch_stats_grads - full_grads = max_utils.unbox_logicallypartioned(full_grads) - else: - full_grads = grads - - if getattr(config, "skip_step_on_spikes", False): - grad_norm = max_utils.l2norm_pytree(grads) - # TrainState.apply_gradients doesn't pass **kwargs to tx.update, so we unpack it manually. - updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params, loss=loss, grad_norm=grad_norm) - new_params = optax.apply_updates(state.params, updates) - - new_state = state.replace( - step=state.step + 1, - params=new_params, - opt_state=new_opt_state, - ) + # Move all parameters to device before optimizer update + if config.parameter_memory_host_offload: + max_logging.log("\nMoving all parameters to device before optimizer update") + + def move(path, value): + max_logging.log(f"train.py: Moving f{path} to device") + return value.with_memory_kind(kind="device") + + state = state.replace( + params=jax.device_put( + state.params, + jax.tree_util.tree_map_with_path(move, state_mesh_shardings.params), + ) + ) + # Re-wrap grads to match state.params structure if it's a dict of collections + # (when weight_sparsity is enabled, params has both 'params' and 'batch_stats' keys). 
+ sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m + if sparsity_enabled: + full_grads = {"params": grads} + if "batch_stats" in state.params: + batch_stats_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params.get("batch_stats", {})) + full_grads["batch_stats"] = batch_stats_grads + full_grads = max_utils.unbox_logicallypartioned(full_grads) + else: + full_grads = grads + + if getattr(config, "skip_step_on_spikes", False): + grad_norm = max_utils.l2norm_pytree(grads) + # TrainState.apply_gradients doesn't pass **kwargs to tx.update, so we unpack it manually. + updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params, loss=loss, grad_norm=grad_norm) + new_params = optax.apply_updates(state.params, updates) + + new_state = state.replace( + step=state.step + 1, + params=new_params, + opt_state=new_opt_state, + ) + else: + new_state = state.apply_gradients(grads=full_grads) + + # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family + if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: + target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") + # Updates the shape to be aligned with state. + moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() + new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) else: - new_state = state.apply_gradients(grads=full_grads) + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) + else: + grads = raw_grads + if config.optimizer_memory_host_offload: + # state.optimizer is an NNX Optimizer module; state_mesh_shardings.optimizer + # is an NNX State. Use nnx.state() to get a compatible State for device_put. 
+ device_opt_shardings = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_device, + state_mesh_shardings.optimizer, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + opt_state = nnx.state(state.optimizer) + new_opt_state = jax.device_put(opt_state, device_opt_shardings) + nnx.update(state.optimizer, new_opt_state) + state.apply_gradients(grads) + new_state = state - # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family - if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: - target_path = ("params", "decoder", "moe_layers", "DeepSeekMoeBlock_0", "MoeBlock_0", "gate", "bias") - # Flax 'sow' returns a tuple, so we take the first element [0]. - # Updates the shape to be aligned with state. - moe_bias_updates = jnp.array(moe_bias_updates[0]).transpose() - new_state = maxtext_utils.update_state_param(new_state, target_path, moe_bias_updates) + # Apply updates for Auxiliary-Loss-Free load balancing for DeepSeek family + if config.routed_bias and config.routed_bias_update_rate > 0.0 and moe_bias_updates is not None: + target_bias = new_state.model.decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.gate.bias + target_bias.value = target_bias.value + jnp.array(moe_bias_updates[0]).transpose() lm_loss = xent_sum / (total_weights + EPS) scalar_metrics = { @@ -458,8 +514,9 @@ def move(path, value): "learning/total_weights": total_weights, } if config.use_qk_clip: - # Apply QK-Clip - new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) + # Apply QK-Clip (Linen path only; NNX uses different state layout — TODO: implement for NNX) + if isinstance(model, nn.Module): + new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) # Report max_logits metric global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs) @@ -469,7 +526,11 @@ def move(path, value): if not config.optimizer_memory_host_offload: 
scalar_metrics["learning/grad_norm"] = max_utils.l2norm_pytree(grads) scalar_metrics["learning/raw_grad_norm"] = max_utils.l2norm_pytree(raw_grads) - scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_state.params) + if isinstance(model, nn.Module): + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_state.params) + else: + _, model_params, _ = nnx.split(new_state.model, nnx.Param, ...) + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(model_params) if config.use_dpo: scalar_metrics["learning/dpo_loss"] = aux["dpo_loss"] scalar_metrics["learning/dpo_reward_accuracy"] = aux["reward_accuracy"] @@ -477,33 +538,34 @@ def move(path, value): "scalar": scalar_metrics, "scalars": {}, } - if config.record_internal_nn_metrics: record_activation_metrics(metrics, intermediate_outputs, config) - if config.use_dpo: - new_state = _merge_dpo_state(new_state, reference_params) - - return new_state, metrics + if isinstance(model, nn.Module): + if config.use_dpo: + new_state = _merge_dpo_state(new_state, reference_params) + return new_state, metrics + return nnx.state(new_state), metrics -def eval_step(model, config, state, data, dropout_rng): +def eval_step(model, config, state, data, dropout_rng=None): """eval_step no backprop and new state compared with train_step.""" + if isinstance(model, nn.Module): + reference_params, extra_dpo_args, _loss_fn = [], [], loss_fn + if config.use_dpo: + state, reference_params = _split_dpo_state(state) + extra_dpo_args = [reference_params] + _loss_fn = dpo_loss_fn - reference_params, extra_dpo_args, _loss_fn = [], [], loss_fn - if config.use_dpo: - state, reference_params = _split_dpo_state(state) - extra_dpo_args = [reference_params] - _loss_fn = dpo_loss_fn - - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - pure_params = state.params["params"] if sparsity_enabled else state.params - batch_stats = state.params.get("batch_stats", {}) + sparsity_enabled = 
config.weight_sparsity_n and config.weight_sparsity_m + pure_params = state.params["params"] if sparsity_enabled else state.params + batch_stats = state.params.get("batch_stats", {}) - eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False) - loss, aux = eval_loss_fn( - pure_params, *extra_dpo_args, sparsity_state=batch_stats - ) + eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False) + loss, aux = eval_loss_fn(pure_params, *extra_dpo_args, sparsity_state=batch_stats) + else: + state = nnx.merge(model, state) # reconstruct TrainStateNNX + loss, aux = loss_fn(state.model, config, data, None, None, is_train=False) mtp_acceptance_rate = 0.0 if config.mtp_eval_target_module > 0: @@ -531,7 +593,7 @@ def eval_step(model, config, state, data, dropout_rng): "evaluation/mtp_acceptance_rate_percent": mtp_acceptance_rate, }, } - if config.use_dpo: + if isinstance(model, nn.Module) and config.use_dpo: metrics["scalar"]["evaluation/dpo_reward_accuracy"] = aux["reward_accuracy"] return metrics @@ -553,32 +615,46 @@ def train_loop(config, recorder, state=None): state, ) = train_utils.setup_train_loop(config, recorder) - if config.use_dpo: - if "reference_params" not in state.params: - reference_params = jax.tree.map(jnp.copy, state.params["params"]) - state = _merge_dpo_state(state, reference_params) - state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + if isinstance(model, nn.Module): + if config.use_dpo: + if "reference_params" not in state.params: + reference_params = jax.tree.map(jnp.copy, state.params["params"]) + state = _merge_dpo_state(state, reference_params) + state_mesh_shardings = _merge_dpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + jit_model = model + else: + if config.use_dpo: + raise NotImplementedError("DPO is not supported for NNX models.") + jit_model, state = nnx.split(state) params_shardings, 
state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) + p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( + config, + jit_model, + mesh, + state, + state_mesh_shardings, + train_step, + eval_step, + eval_data_iterator, + params_shardings, + ) + with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules): - p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( - config, - model, - mesh, - state, - state_mesh_shardings, - train_step, - eval_step, - eval_data_iterator, - params_shardings, - ) shaped_batch = maxtext_utils.get_shaped_batch(config) - if config.shard_optimizer_over_data: + if config.shard_optimizer_over_data and isinstance(model, nn.Module): state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) - maxtext_utils.maybe_dump_jaxpr(config, p_train_step, (state, shaped_batch, init_rng)) + elif config.shard_optimizer_over_data: + # NNX: reshard state so params match the data-sharded in_shardings (Zero-1 layout) + state = jax.device_put(state, state_mesh_shardings) + if isinstance(model, nn.Module): + lower_args = (state, shaped_batch, init_rng) + else: + lower_args = (state, shaped_batch) + maxtext_utils.maybe_dump_jaxpr(config, p_train_step, lower_args) if config.compiled_trainstep_file == "": # compile only when there is no pre-compiled file loaded - compiled = p_train_step.lower(state, shaped_batch, init_rng).compile() + compiled = p_train_step.lower(*lower_args).compile() compiled_stats = compiled.memory_analysis() max_utils.print_compiled_memory_stats(compiled_stats) @@ -587,7 +663,11 @@ def train_loop(config, recorder, state=None): metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) # Write train config params, num model params, and XLA flags to tensorboard - metric_logger.write_setup_info_to_tensorboard(state.params) + if isinstance(model, nn.Module): + setup_params = state.params + else: 
+ _, setup_params, _ = nnx.split(state.model, nnx.Param, ...) + metric_logger.write_setup_info_to_tensorboard(setup_params) _job_completed_gracefully = False try: @@ -597,57 +677,60 @@ def train_loop(config, recorder, state=None): with jax.profiler.StepTraceAnnotation("train", step_num=step): example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager) - # pylint: disable=not-callable - nextrng = jax.jit(jax.random.fold_in)(init_rng, step) + if isinstance(model, nn.Module): + # pylint: disable=not-callable + step_rng_args = (jax.jit(jax.random.fold_in)(init_rng, step),) + else: + step_rng_args = () with maybe_record_goodput(recorder, GoodputEvent.STEP, step): with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): - if config.shard_optimizer_over_data: + if config.shard_optimizer_over_data and isinstance(model, nn.Module): state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) - state, metrics = p_train_step(state, example_batch, nextrng) - - step_time_delta = datetime.datetime.now() - last_step_completion - last_step_completion = datetime.datetime.now() - - state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] - checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) - - if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): - jax.block_until_ready(state) # Ensure compilation has finished. 
- gcs_utils.upload_dump( - config.dump_hlo_local_dir, - config.dump_hlo_gcs_dir, - module_name=config.dump_hlo_module_name, - delete_local_after=config.dump_hlo_delete_local_after, - all_host_upload=config.dump_hlo_upload_all, - ) - - if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: - assert eval_data_iterator - # Explicitly reset the eval iterator and counters before starting the eval loop - eval_data_iterator.reset() - metric_logger.reset_eval_metrics() - - eval_step_count = 0 - # pylint: disable=not-callable - for eval_batch in eval_data_iterator: - if config.eval_steps > 0 and eval_step_count >= config.eval_steps: - break - with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): - eval_metrics = p_eval_step(state, eval_batch, nextrng) - metric_logger.record_eval_metrics(step, metrics=eval_metrics) - max_logging.log(f"Completed eval step {eval_step_count}") - eval_step_count += 1 - metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count) - if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss: - prof.deactivate() - raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") - - prof.maybe_deactivate_profiler(step, state) - - if step == start_step: - max_utils.print_mem_stats("After params initialized") - - metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) + state, metrics = p_train_step(state, example_batch, *step_rng_args) + + step_time_delta = datetime.datetime.now() - last_step_completion + last_step_completion = datetime.datetime.now() + + state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] + checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) + + if config.dump_hlo and step == (config.dump_step if config.dump_step >= 0 else start_step): + jax.block_until_ready(state) # Ensure compilation has finished. 
+ gcs_utils.upload_dump( + config.dump_hlo_local_dir, + config.dump_hlo_gcs_dir, + module_name=config.dump_hlo_module_name, + delete_local_after=config.dump_hlo_delete_local_after, + all_host_upload=config.dump_hlo_upload_all, + ) + + if config.eval_interval > 0 and step > start_step and (step + 1) % config.eval_interval == 0: + assert eval_data_iterator + # Explicitly reset the eval iterator and counters before starting the eval loop + eval_data_iterator.reset() + metric_logger.reset_eval_metrics() + + eval_step_count = 0 + # pylint: disable=not-callable + for eval_batch in eval_data_iterator: + if config.eval_steps > 0 and eval_step_count >= config.eval_steps: + break + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): + eval_metrics = p_eval_step(state, eval_batch, *step_rng_args) + metric_logger.record_eval_metrics(step, metrics=eval_metrics) + max_logging.log(f"Completed eval step {eval_step_count}") + eval_step_count += 1 + metric_logger.record_eval_metrics(step, eval_step_count=eval_step_count) + if metric_logger.cumulative_eval_metrics["scalar"]["eval/avg_loss"] <= config.target_eval_loss: + prof.deactivate() + raise exceptions.StopTraining(f"Target loss {config.target_eval_loss=} is achieved.") + + prof.maybe_deactivate_profiler(step, state) + + if step == start_step: + max_utils.print_mem_stats("After params initialized") + + metric_logger.buffer_and_write_train_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: state_to_save = state if not config.use_dpo else _split_dpo_state(state)[0] diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py index 9bad1cfb35..e1699647c6 100644 --- a/src/maxtext/utils/gradient_accumulation.py +++ b/src/maxtext/utils/gradient_accumulation.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp from jax.sharding import NamedSharding +from flax import nnx from maxtext.common.common_types import ShardMode from 
maxtext.utils.sharding import maybe_shard_with_name @@ -49,7 +50,8 @@ def gradient_accumulation_loss_and_grad( config: Model and training configuration object. Must contain `gradient_accumulation_steps` and `shard_optimizer_over_data`. model: The model module. - params: The model parameters (PyTree). + params: The model parameters (PyTree). This is only used for Linen. For NNX, + we can get the params from the model. params_shardings: The sharding constraints for the parameters (PyTree). data: A PyTree of batched data. The leading dimension is assumed to be the total batch size (microbatch_size * num_accumulations). @@ -67,12 +69,18 @@ def _maybe_shard_with_name(inputs, sharding_names): """Wrapper of maybe_shard_with_name with fixed shard_mode""" return maybe_shard_with_name(inputs, sharding_names, config.shard_mode, debug_sharding=config.debug_sharding) + is_nnx = isinstance(model, nnx.Module) + # For more efficient DP/ZeRO-1 + GA if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1: ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings) grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings) else: ga_params_shardings = grad_shardings = params_shardings + + if is_nnx: + graphdef, params, rest = nnx.split(model, nnx.Param, ...) 
+ # When using Zero-1 optimizer sharding, cast params to lower precision and apply sharding constraints # so that all-gather is done once in the lower precision before the gradient accumulation loop if config.shard_optimizer_over_data: @@ -87,11 +95,27 @@ def convert_to_bf16(param): ga_params = params ga_params = jax.tree.map(_maybe_shard_with_name, ga_params, ga_params_shardings) - grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) + if is_nnx: + grad_func = nnx.value_and_grad(_loss_fn, argnums=0, has_aux=True) + else: + grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True) def accumulate_gradient(acc_grad_and_loss, data): ga_params = acc_grad_and_loss["ga_params"] - (_, aux), cur_batch_gradient = grad_func(model, config, data, dropout_rng, ga_params, *extra_dpo_args, is_train=True) + if is_nnx: + # Reconstruct the model using the fixed parameters (ga_params) + # and the advancing non-parameter state (RNGs) from the carry. + local_model = nnx.merge(graphdef, ga_params, acc_grad_and_loss["rest_state"]) + (_, aux), cur_batch_gradient = grad_func(local_model, config, data, None, None, *extra_dpo_args, is_train=True) + _, _, next_rest_state = nnx.split(local_model, nnx.Param, ...) 
+ acc_grad_and_loss["rest_state"] = next_rest_state + else: + rng = ( + jax.random.fold_in(dropout_rng, acc_grad_and_loss["total_weights"].astype(jnp.int32)) + if dropout_rng is not None + else None + ) + (_, aux), cur_batch_gradient = grad_func(model, config, data, rng, ga_params, *extra_dpo_args, is_train=True) acc_grad_and_loss["loss"] += aux["xent_sum"] + aux.get("dpo_loss", 0.0) acc_grad_and_loss["moe_lb_loss"] += aux["moe_lb_loss"] acc_grad_and_loss["indexer_loss"] += aux["indexer_loss"] @@ -119,6 +143,8 @@ def reshape_to_microbatch_accumulations(batch_arr): "mtp_loss": 0.0, "ga_params": ga_params, } + if is_nnx: + init_grad_and_loss["rest_state"] = rest grad_and_loss, aux = jax.lax.scan( accumulate_gradient, init_grad_and_loss, data, length=config.gradient_accumulation_steps @@ -134,6 +160,9 @@ def reshape_to_microbatch_accumulations(batch_arr): raw_grads = jax.tree_util.tree_map(lambda arr: arr / grad_and_loss["total_weights"], raw_grads) aux = jax.tree.map(lambda x: jnp.sum(x, axis=0), aux) # pytype: disable=module-attr + if is_nnx: + nnx.update(model, grad_and_loss["rest_state"]) + return loss, aux, raw_grads diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 5458b35a7d..d03f60766c 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -20,21 +20,20 @@ import os from typing import Sequence -from flax import linen as nn +from flax import nnx, linen as nn +from flax.core.spmd import composite_rules, from_sharding_rules, get_logical_axis_rules from flax.linen import partitioning as nn_partitioning -from flax.training import train_state +from flax.training.train_state import TrainState import numpy as np -from jax.experimental import mesh_utils -from jax.experimental.serialize_executable import deserialize_and_load -from jax.sharding import AxisType, Mesh - import jax import jax.numpy as jnp +from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec +from jax.experimental 
import mesh_utils +from jax.experimental.serialize_executable import deserialize_and_load import optax - import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager @@ -54,6 +53,7 @@ from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import sharding +from maxtext.utils import maxtext_utils_nnx OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" @@ -101,7 +101,10 @@ def get_functional_train_with_signature( """Get the shardings (both state and data) for `train_step`.""" functional_train = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings) functional_train.__name__ = "train_step" - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + if config.pure_nnx: + in_shardings = (state_mesh_shardings, data_sharding) # State, batch + else: + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng out_shardings = (state_mesh_shardings, None) # State, metrics static_argnums = () # We partial out the static argnums of model and config donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory. 
@@ -112,7 +115,10 @@ def get_functional_eval_with_signature(eval_step, data_sharding, state_mesh_shar """Get the shardings (both state and data) for `eval_step`.""" functional_eval = functools.partial(eval_step, model, config) functional_eval.__name__ = "eval_step" - in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + if config.pure_nnx: + in_shardings = (state_mesh_shardings, data_sharding) # State, batch (NNX: no rng) + else: + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng out_shardings = None # metrics static_argnums = () # We partial out the static argnums of model, config donate_argnums = () # state will be kept instead of being donated in eval_step @@ -1201,15 +1207,15 @@ def _apply_update(path, param): return state.replace(params=new_params) -def init_decode_state(apply_fn, params) -> train_state.TrainState: +def init_decode_state(apply_fn, params) -> TrainState: """Init train state with null opt state for decode.""" - state = train_state.TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore + state = TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore return state def init_training_state(apply_fn, params, tx): """Init train state with null opt state for decode.""" - state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=tx) + state = TrainState.create(apply_fn=apply_fn, params=params, tx=tx) return state @@ -1337,7 +1343,7 @@ def setup_initial_state( is_training: True to initialize training state, False for decode state Returns: - state: the initialized train state + train_state: the initialized train state. 
For NNX, this is a TrainStateNNX instance state_mesh_annotations: the mesh annotations for the train state """ @@ -1376,33 +1382,48 @@ def setup_initial_state( else: # The update of data_iterator state happens in place, no need to assign explicitly state = restored["items"] + + # For NNX, convert the pure dict to nnx.State using the abstract state as template + if config.pure_nnx: + nnx.replace_by_pure_dict(unboxed_abstract_state, state) + state = unboxed_abstract_state else: init_state_partial = init_state_fn init_state_partial.__name__ = "initialize_state" - # pylint: disable=not-callable - state = jax.jit( - init_state_partial, - in_shardings=None, - out_shardings=state_mesh_shardings, - )() - sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m - if ( - sparsity_enabled and raw_params - ): # If we loaded a partial state, we need to merge it. - - def _merge_params(p_raw, p_init): - if isinstance(p_raw, jax.ShapeDtypeStruct): - return p_init - return p_raw - - merged_params = jax.tree_util.tree_map( - _merge_params, raw_params, state.params - ) - state = state.replace(params=merged_params) - elif raw_params: - state = state.replace(params=raw_params) - - state = max_utils.unbox_logicallypartioned(state) + if config.pure_nnx: + state = jax.jit( + lambda: nnx.state(init_state_partial()), # Get state only, mapping to out_sharding structure + in_shardings=None, + out_shardings=state_mesh_shardings, + )() + else: + # pylint: disable=not-callable + state = jax.jit( + init_state_partial, + in_shardings=None, + out_shardings=state_mesh_shardings, + )() + if raw_params: # If we loaded a partial state, we need to merge it. 
+ if config.pure_nnx: + # raw_params should have the same sharding info as in the model + nnx.update(state.model, raw_params) + else: + sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m + if sparsity_enabled: + # Sparsity-init keeps freshly initialized params for any leaf still + # represented as an abstract ShapeDtypeStruct in raw_params (i.e. not + # actually restored), and uses the restored value otherwise. + def _merge_params(p_raw, p_init): + if isinstance(p_raw, jax.ShapeDtypeStruct): + return p_init + return p_raw + + merged_params = jax.tree_util.tree_map(_merge_params, raw_params, state.params) + state = state.replace(params=merged_params) + else: + state = state.replace(params=raw_params) + if not config.pure_nnx: + state = max_utils.unbox_logicallypartioned(state) return state, state_mesh_annotations, state_mesh_shardings, data_iterator @@ -1417,6 +1438,9 @@ def get_logical_annotations(config, mesh, init_state_fn): def get_abstract_state(config, mesh, init_state_fn, is_training=True): """Get a shaped abstraction of the state (including optimizer)""" + if config.pure_nnx: + return get_abstract_state_nnx(config, mesh, init_state_fn, is_training) + init_state_partial = init_state_fn with nn_partitioning.axis_rules(config.logical_axis_rules): @@ -1460,6 +1484,148 @@ def move(path, x): ) +def get_nnx_named_sharding_with_scan_axis(abs_var_state: nnx.State, mesh) -> nnx.State: + """Compute NamedSharding for each NNX variable, correctly handling the scan (stacked layers) axis. + + Unlike flax.nnx.spmd.get_var_pspec (used inside nnx.get_abstract_model), this function also + inserts the partition_name axis at the correct scan_axis position for parameters created by + _create_scanned_layers. Without this, scanned parameters get a 2D partition spec applied to a + 3D tensor, placing sharding on the stacked-layers dimension instead of the embedding dimension. 
+ + Args: + abs_var_state: NNX abstract variable state from nnx.split(nnx.eval_shape(...)). + mesh: JAX physical mesh. + + Returns: + Same tree structure as abs_var_state but each Variable's value replaced with NamedSharding. + """ + + def _make_named_sharding(v): + val = v.get_value() + if not hasattr(val, "shape"): + # Non-tensor value (e.g., optax MaskedNode for non-trainable params). Preserve + # as-is so the treedef matches abs_var_state in the downstream jax.tree.map. + return v + metadata = v.get_metadata() + out_sharding = metadata.get("out_sharding") or metadata.get("sharding_names") or metadata.get("sharding") + if not out_sharding: + pspec = PartitionSpec() + else: + # Insert the scan axis for parameters created by _create_scanned_layers. + # _add_scan_metadata stores the axis name in nnx.PARTITION_NAME and the + # axis index in "param_scan_axis". flax.nnx.spmd.get_var_pspec ignores these. + if nnx.PARTITION_NAME in metadata: + partition_name = metadata[nnx.PARTITION_NAME] + # Always use param_scan_axis from metadata. OptVariable (optimizer state) inherits + # param_scan_axis=1 from the model Param via to_opt_state(), so we must not hardcode + # scan_axis=0 for non-Param types. stacked_rest non-Param variables have + # param_scan_axis=0 set explicitly by _add_scan_metadata, so this is always correct. + scan_axis = metadata.get("param_scan_axis", 0) + out_sharding = [out_sharding] if isinstance(out_sharding, str) else list(out_sharding) + # Guard against double-insertion: Flax 0.12.6 _remap_sharding_metadata renames + # 'sharding' -> 'out_sharding', so _add_scan_metadata may have already inserted + # the scan axis. Only insert if not already present. + if partition_name not in out_sharding: + out_sharding.insert(scan_axis, partition_name) + out_sharding = tuple(out_sharding) + # Convert logical axis names to physical mesh axes using current context rules. 
+ context_rules = get_logical_axis_rules() + local_rules = metadata.get("sharding_rules", ()) + if context_rules or local_rules: + rules = composite_rules(context_rules, local_rules) + pspec = PartitionSpec(*from_sharding_rules(out_sharding, rules)) + else: + pspec = PartitionSpec(*out_sharding) + return v.replace(NamedSharding(mesh, pspec)) + + return jax.tree.map(_make_named_sharding, abs_var_state, is_leaf=lambda x: isinstance(x, nnx.Variable)) + + +def get_abstract_state_nnx(config, mesh, nnx_init_trainstate_fn, is_training=True): + """Calculates the abstract sharded state and memory placement for an NNX TrainState. + + This function performs an abstract trace of the NNX model and optimizer using + `nnx.get_abstract_model`. It resolves logical sharding annotations into physical + JAX shardings and applies memory placement optimizations such as optimizer + sharding and host memory offloading (pinning to CPU RAM). + + Args: + config: Configuration object containing sharding and offloading hyperparameters + (e.g., shard_optimizer_over_data, optimizer_memory_host_offload). + mesh: JAX physical mesh used to resolve logical axis names to physical devices. + nnx_init_trainstate_fn: A zero-argument factory function that produces a + TrainStateNNX instance during the abstract trace. + is_training: Boolean indicating if the state is for training. If True, + optimizer state is processed and memory offloading strategies are applied. + + Returns: + A tuple containing (abstract_sharded_state, None, state_mesh_shardings): + abstract_sharded_state: An nnx.State containing ShapeDtypeStructs with + fully resolved physical sharding and memory_kind metadata. + state_mesh_annotations: An nnx.State tree consisting of the raw PartitionSpec + objects corresponding to each parameter/variable. + state_mesh_shardings: An nnx.State tree consisting of the raw JAX + Sharding objects corresponding to each parameter/variable. 
+ """ + assert nnx_init_trainstate_fn is not None, "get_abstract_state_nnx: init function must be given." + + with nn_partitioning.axis_rules(config.logical_axis_rules): + # Use nnx.eval_shape + nnx.split instead of nnx.get_abstract_model, so we can apply + # get_nnx_named_sharding_with_scan_axis which correctly inserts the stacked-layers + # axis into the partition spec. nnx.get_abstract_model uses get_var_pspec internally + # which ignores nnx.PARTITION_NAME / param_scan_axis metadata set by _create_scanned_layers, + # causing the 2D partition spec to be misapplied to the 3D stacked parameter tensor. + # Do NOT wrap nnx.eval_shape in jax.set_mesh: Flax 0.12.6's _to_variable calls + # var.shape for every variable when a global mesh is active, but masked optimizer + # state variables (e.g. from trainable_parameters_mask) have value=MaskedNode() + # which has no .shape and would raise AttributeError. We handle sharding + # ourselves via get_nnx_named_sharding_with_scan_axis, so auto-assignment is not + # needed here. 
+ abs_model = nnx.eval_shape(nnx_init_trainstate_fn) + _, abs_var_state = nnx.split(abs_model) + named_sharding_state = get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) + abstract_state = jax.tree.map( + lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), + abs_var_state, + named_sharding_state, + ) + + state_mesh_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) + + if is_training and config.shard_optimizer_over_data: + # Add data to sharding for optimizer state + optimizer_sharding = jax.tree_util.tree_map_with_path( + functools.partial(sharding.add_data_to_sharding, mesh), + abstract_state.optimizer, + state_mesh_shardings.optimizer, + ) + state_mesh_shardings.optimizer = optimizer_sharding + if is_training and config.optimizer_memory_host_offload: + optimizer_sharding = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_host, + state_mesh_shardings.optimizer, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + state_mesh_shardings.optimizer = optimizer_sharding + if is_training and config.parameter_memory_host_offload: + assert config.param_scan_axis == 0, "You must set the scan axis 0 to enable parameter offloading." + _, state_params, _ = nnx.split(state_mesh_shardings, nnx.Param, ...) 
+ state_params = jax.tree_util.tree_map_with_path( + maxtext_utils_nnx.move_memory_to_host, + state_params, + is_leaf=lambda x: isinstance(x, NamedSharding), + ) + nnx.update(state_mesh_shardings, state_params) + + abstract_sharded_state = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, state_mesh_shardings) + state_mesh_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings) + return ( + abstract_sharded_state, + state_mesh_annotations, + state_mesh_shardings, + ) + + def get_prefill_kv_cache_annotations(model, config, rng, mesh, page_state: None | PageState = None): """Get a shaped abstraction of the state (including optimizer)""" diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index c37e6b52ad..fd068c720c 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -1,3 +1,17 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ # Copyright 2023–2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,11 +32,11 @@ import dataclasses import collections from collections.abc import Sequence +from typing import Callable, overload from functools import partial import os import subprocess import sys -from typing import overload from etils import epath from flax import nnx import flax.linen as nn @@ -223,34 +237,99 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng return model -def create_nnx_abstract_model(config, mesh, model_mode=MODEL_MODE_TRAIN, rng_key=None): - """Returns (_create_model_partial, abstract_model) for AOT compilation. +def get_nnx_create_model_fn(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None) -> Callable: - This does not shard parameters or load checkpoints. It only builds the - abstract shape/dtype structure needed by get_abstract_state and optimizer - construction (e.g. Muon). + def _create_model(): + rngs = maxtext_utils_nnx.create_nnx_rngs(config, model_mode=model_mode, rng_key=rng_key) + return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode) - Args: - config: the configuration - mesh: the device mesh - model_mode: train or inference - rng_key: optional RNG key + return _create_model + + +def create_nnx_abstract_model( + config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None +) -> tuple[Callable, nnx.Module]: + """Creates an abstract NNX model. Returns: - (_create_model_partial, abstract_model) where _create_model_partial() creates - a concrete model instance and abstract_model is the eval_shape result. + A tuple containing (create_model_fn, abstract_model): + create_model_fn: A zero-argument callable that produces a new model instance. + abstract_model: The stateful NNX model instance in an abstract state. 
""" - def _create_model(rng_key=None): - rngs = maxtext_utils_nnx.create_nnx_rngs(config, model_mode=model_mode, rng_key=rng_key) - return from_config(config, mesh=mesh, rngs=rngs, model_mode=model_mode) + with nn.logical_axis_rules(config.logical_axis_rules): + _create_model = get_nnx_create_model_fn(config, mesh, devices, model_mode, rng_key) + if mesh is None: + _tmp = nnx.eval_shape(_create_model) + mesh = _tmp.mesh + # Use nnx.eval_shape + our scan-axis-aware sharding helper instead of + # nnx.get_abstract_model, which uses get_var_pspec internally and ignores + # param_scan_axis / nnx.PARTITION_NAME metadata set by _create_scanned_layers, + # causing the stacked layers axis to be missing from the PartitionSpec. + with jax.set_mesh(mesh): + abs_model = nnx.eval_shape(_create_model) + graphdef, abs_var_state = nnx.split(abs_model) + named_sharding_state = maxtext_utils.get_nnx_named_sharding_with_scan_axis(abs_var_state, mesh) + abstract_state = jax.tree.map( + lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s), + abs_var_state, + named_sharding_state, + ) + return _create_model, nnx.merge(graphdef, abstract_state) + + +def create_nnx_sharded_model_hybrid(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): + """Creates a sharded model for hybrid NNX modules containing Linen sub-modules. - _create_model_partial = partial(_create_model, rng_key=rng_key) + DEPRECATED: This function is a transitional utility for the Linen-to-NNX + migration. It should be removed once all model components are ported to + pure NNX modules. + + This function specifically handles the complexity of "mixed" state initialization, + where logical sharding annotations must be resolved for both NNX native + Parameters and legacy Linen variables wrapped via the NNX-Linen bridge. + It ensures that both systems correctly respect the provided mesh and + logical axis rules during the abstraction/sharding planning phase. 
+ """ + _create_model_partial = get_nnx_create_model_fn(config, mesh, devices, model_mode, rng_key) with nn.logical_axis_rules(config.logical_axis_rules): abstract_model = nnx.eval_shape(_create_model_partial) + graphdef, abstract_state = nnx.split(abstract_model) + specs = nnx.get_partition_spec(abstract_state) + + if mesh is None: + mesh = abstract_model.mesh + + # JIT a function that creates the model state with proper sharding from the start. + # By providing out_shardings, we instruct JAX to produce sharded output directly, + # avoiding a large intermediate allocation on a single device. + with nn.logical_axis_rules(config.logical_axis_rules): + out_shardings = nn.logical_to_mesh_sharding(specs, mesh) - return _create_model_partial, abstract_model + @partial(jax.jit, out_shardings=out_shardings) + def create_sharded_state(): + # This will be JIT-compiled. JAX knows the output sharding and can + # initialize the parameters directly on the target devices in a sharded way. + model = _create_model_partial() + return nnx.state(model) + + with mesh: + # Create the model with sharded parameters. 
+ with nn.logical_axis_rules(config.logical_axis_rules): + sharded_state = create_sharded_state() + model = nnx.merge(graphdef, sharded_state) + + # print weights sharding info under debug sharding mode + if config.debug_sharding: + max_utils.print_non_trivial_mesh_axis(model.mesh) + maxtext_utils.print_shardings_params( + params=sharded_state, + params_sharding=out_shardings, + mesh=model.mesh, + logical_annotations=specs, + ) + return model def setup_configs_and_devices(argv: list[str] | None = None, kwargs: dict | None = None, **extra_kwargs): @@ -435,60 +514,19 @@ def from_pretrained( ) config = pyconfig.HyperParameters(new_config) - def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None): - rngs = maxtext_utils_nnx.create_nnx_rngs(config, model_mode=model_mode, rng_key=rng_key) - return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode) - - _create_model_partial = partial(_create_model, mesh=mesh, model_mode=model_mode, rng_key=rng_key) + if config.pure_nnx: + _create_model, abstract_model = create_nnx_abstract_model(config, mesh, devices, model_mode, rng_key) + model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model, mesh=mesh) + # TODO: print debug_sharding info + else: + model = create_nnx_sharded_model_hybrid(config, mesh, devices, model_mode, rng_key) - with nn.logical_axis_rules(config.logical_axis_rules): - abstract_model = nnx.eval_shape(_create_model_partial) - graphdef, abstract_state = nnx.split(abstract_model) - specs = nnx.get_partition_spec(abstract_state) + sharded_state = nnx.state(model) if mesh is None: - mesh = abstract_model.mesh - - # Note for pure_nnx: - # Currently, the NNX model returned has a linen decoder wrapped to NNX. So it is not a pure NNX model and - # we still need to use nn.logical_axis_rules(config.logical_axis_rules) to get the out sharding from the linen - # LogicallyPartitioned structure. 
- # In the future if the pure NNX model is used, with pure NNX's eager sharding, there will be no LogicallyPartitioned - # structure in the abstract state and we can get the sharded state with the following code: - # graphdef, state = nnx.get_abstract_model(_create_model_partial, mesh) - # abstract_model = nnx.merge(graphdef, state) - # model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model_partial, mesh=mesh) - # sharded_state = nnx.state(model) - - # JIT a function that creates the model state with proper sharding from the start. - # By providing out_shardings, we instruct JAX to produce sharded output directly, - # avoiding a large intermediate allocation on a single device. - with nn.logical_axis_rules(config.logical_axis_rules): - out_shardings = nn.logical_to_mesh_sharding(specs, mesh) - - @partial(jax.jit, out_shardings=out_shardings) - def create_sharded_state(): - # This will be JIT-compiled. JAX knows the output sharding and can - # initialize the parameters directly on the target devices in a sharded way. - model = _create_model_partial() - return nnx.state(model) + mesh = model.mesh with mesh: - # Create the model with sharded parameters. 
- with nn.logical_axis_rules(config.logical_axis_rules): - sharded_state = create_sharded_state() - model = nnx.merge(graphdef, sharded_state) - - # print weights sharding info under debug sharding mode - if config.debug_sharding: - max_utils.print_non_trivial_mesh_axis(model.mesh) - maxtext_utils.print_shardings_params( - params=sharded_state, - params_sharding=out_shardings, - mesh=model.mesh, - logical_annotations=specs, - ) - if config.load_parameters_path: try: ckptr = ocp.Checkpointer( diff --git a/src/maxtext/utils/muon_utils.py b/src/maxtext/utils/muon_utils.py index 3ba60d7371..3bd2b186b1 100644 --- a/src/maxtext/utils/muon_utils.py +++ b/src/maxtext/utils/muon_utils.py @@ -24,25 +24,23 @@ python3 -m maxtext.utils.muon_utils qwen3-4b True """ - import os import sys from typing import Optional, Tuple import flax.linen as nn +from flax import nnx import jax from maxtext.configs import pyconfig from maxtext.utils.globals import MAXTEXT_PKG_DIR from maxtext.layers import quantizations from maxtext.models import models -from maxtext.utils import maxtext_utils +from maxtext.utils import maxtext_utils, model_creation_utils from optax.contrib._muon import MuonDimensionNumbers as mdn -Transformer = models.transformer_as_linen - - def _is_path_contain_any(tuples, path): + """Checks if any element in 'tuples' is present in 'path'.""" return any(x in path for x in tuples) @@ -107,10 +105,25 @@ def get_transform_tree(tree, path=()): def get_muon_weight_dimension_numbers(model, config, verbose=False): """Extract muon dimension number from model structure.""" - # quickly get param structure without materialization - abstract_param = maxtext_utils.get_abstract_param(model, config) - # get muon dimension number from param - muon_weight_dimension_numbers = get_transform_tree(abstract_param) + + if isinstance(model, nnx.Module): + _, abstract_param, _ = nnx.split(model, nnx.Param, ...) 
+ + def apply_transform_nnx(path: Tuple[jax.tree_util.KeyEntry, ...], leaf): + # Convert jax.tree_util.KeyEntry path to Tuple[str, ...] + path_strings = tuple(p.key for p in path if isinstance(p, jax.tree_util.DictKey)) + return transform_logic(path_strings) + + # Use jax.tree_util.tree_map_with_path for NNX's potentially complex PyTree structure. + # This is different with linen where abstract_param is a dict-based tree with nn.LogicallyPartitioned leaves. + muon_weight_dimension_numbers = jax.tree_util.tree_map_with_path(apply_transform_nnx, abstract_param) + + else: # Linen + # quickly get param structure without materialization + abstract_param = maxtext_utils.get_abstract_param(model, config) + # get muon dimension number from param + muon_weight_dimension_numbers = get_transform_tree(abstract_param) + if verbose: _print_structure_debug(abstract_param, muon_weight_dimension_numbers) return muon_weight_dimension_numbers @@ -118,19 +131,30 @@ def get_muon_weight_dimension_numbers(model, config, verbose=False): def _print_structure_debug(abstract_param, muon_weight_dimension_numbers): """Prints the model structure and the resulting Muon config.""" - # Access the shape from the inner ShapeDtypeStruct and names from the wrapper - # Return a new tree with the same structure containing only shapes/names + + def get_leaf_info(leaf): + # For linen: + # Access the shape from the inner ShapeDtypeStruct and names from the wrapper + # Return a new tree with the same structure containing only shapes/names + if isinstance(leaf, nn.LogicallyPartitioned): + return {"shape": leaf.value.shape, "names": leaf.names} + # For nnx: + # Only return the shape because it doesn't have a wrapper. 
+ elif isinstance(leaf, jax.ShapeDtypeStruct): + return {"shape": leaf.shape} + return {"shape": "N/A"} + info_tree = jax.tree_util.tree_map( - lambda leaf: {"shape": leaf.value.shape, "names": leaf.names}, + get_leaf_info, abstract_param, - is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned), + is_leaf=lambda x: isinstance(x, (nn.LogicallyPartitioned, jax.ShapeDtypeStruct)), ) print(f"\n=== Model Structure ===\n{info_tree}") print(f"\n=== Muon Dimension Numbers ===\n{muon_weight_dimension_numbers}") print("\nIs this reasonable?") -def get_model_mdn(model_name, scan_layers=True, verbose=False): +def get_model_mdn(model_name, scan_layers=True, verbose=False, pure_nnx=False): """Initializes a model and retrieves its Muon dimension numbers. This function sets up the configuration for a given model, initializes the @@ -154,13 +178,17 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False): f"model_name={model_name}", f"scan_layers={scan_layers}", "attention=dot_product", + f"pure_nnx={pure_nnx}", ] config = pyconfig.initialize(argv) # Setup model devices_array = maxtext_utils.create_device_mesh(config) mesh = jax.sharding.Mesh(devices_array, config.mesh_axes) quant = quantizations.configure_quantization(config) - model = Transformer(config, mesh=mesh, quant=quant) + if pure_nnx: + _, model = model_creation_utils.create_nnx_abstract_model(config, mesh) + else: + model = models.transformer_as_linen(config, mesh=mesh, quant=quant) # Get dimension number muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config, verbose=verbose) return muon_weight_dimension_numbers @@ -172,4 +200,4 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False): sys.exit(1) model_name_arg = sys.argv[1] scan_layers_arg = sys.argv[2].lower() == "true" - get_model_mdn(model_name_arg, scan_layers_arg, verbose=True) + get_model_mdn(model_name_arg, scan_layers_arg, verbose=True, pure_nnx=False) diff --git a/src/maxtext/utils/sharding.py 
b/src/maxtext/utils/sharding.py index d4bb64f016..4a500e2fe1 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -15,7 +15,7 @@ # pylint: disable=line-too-long, disable=bare-except, consider-using-generator """ Utils that are only interesting to MaxText and sharding related. """ -from flax import linen as nn +from flax import linen as nn, nnx from collections.abc import Iterable @@ -25,6 +25,7 @@ import optax +from maxtext.configs import pyconfig from maxtext.common.common_types import ShardMode from maxtext.utils import max_logging from maxtext.utils import max_utils @@ -483,6 +484,8 @@ def maybe_update_params_sharding_with_opt(config, state_mesh_shardings): - updated_state_mesh_shardings: State mesh shardings with updated params field (unchanged if shard_optimizer_over_data is False) """ + if config.pure_nnx: + return maybe_update_params_sharding_with_opt_nnx(config, state_mesh_shardings) prev_params_shardings = state_mesh_shardings.params if config.shard_optimizer_over_data: if isinstance(state_mesh_shardings.opt_state, optax.ScaleByAdamState): @@ -501,6 +504,122 @@ def maybe_update_params_sharding_with_opt(config, state_mesh_shardings): return prev_params_shardings, state_mesh_shardings +def maybe_update_params_sharding_with_opt_nnx( + config: pyconfig.HyperParameters, state_mesh_shardings: nnx.State +) -> tuple[nnx.State, nnx.State]: + """ + NNX version of parameter sharding update. Updates parameter sharding configuration + when optimizer state sharding is enabled. + + When shard_optimizer_over_data is enabled (Zero-1 style sharding), this function + extracts the optimizer state shardings from the Adam optimizer's first moment (mu) + and merges them with the parameter shardings. This ensures parameter sharding is + consistent with how the optimizer state is distributed across the compute mesh. + + Args: + config: Configuration with shard_optimizer_over_data flag. 
+ state_mesh_shardings: The sharding state for a TrainStateNNX container. + + Returns: + A tuple of (prev_params_shardings, updated_state_mesh_shardings): + - prev_params_shardings: Original parameter shardings before the update + - updated_state_mesh_shardings: State mesh shardings with updated params field + (unchanged if shard_optimizer_over_data is False)""" + # In TrainStateNNX, parameters are under 'model' + model_shardings = state_mesh_shardings.model + + def _extract_param_only(state): + """Recursively extract nnx.Param variables from an nnx.State into a nested plain dict. + + Constructs nnx.State({'key': nested_dict, ...}) which produces the same pytree + structure as nnx.split(model, nnx.Param, ...)[1], enabling jax.tree.map + to work correctly between ga_params (Param-only) and params_shardings. + """ + result = {} + for k, v in state.items(): + if isinstance(v, nnx.Param): + result[k] = v + elif isinstance(v, nnx.Variable): + pass # skip non-Param variables (RngKey, RngCount, OptVariable, etc.) + elif hasattr(v, "items"): + sub = _extract_param_only(v) + if sub: + result[k] = sub + return result + + # prev_params_shardings must match the pytree structure of ga_params from + # nnx.split(model, nnx.Param, ...) — Param variables only, no rngs. + prev_params_shardings = nnx.State(_extract_param_only(model_shardings)) + + if not config.shard_optimizer_over_data: + return prev_params_shardings, state_mesh_shardings + + sharded_fp32_params = None + # Check if the optimizer has any state at all (stateless optimizers like SGD omit this key) + if "opt_state" in state_mesh_shardings.optimizer: + # Access the optimizer branch to find the optax state + # state_mesh_shardings.optimizer contains the sharding for the nnx.Optimizer + opt_state = state_mesh_shardings.optimizer.opt_state + + def find_adam_mu(obj): + # 1. Direct hit on ScaleByAdamState (Linen path or unflattened NNX) + if isinstance(obj, optax.ScaleByAdamState): + return obj.mu + + # 2. 
Check for flattened ScaleByAdamState (nnx.State/dict) + # These nodes contain 'mu', 'nu', and 'count' as keys. + if hasattr(obj, "__getitem__") and "mu" in obj and "nu" in obj: + return obj["mu"] + + # 3. Recursive search through containers (nnx.State, dict, list, tuple) + values = None + if hasattr(obj, "values"): # Handles nnx.State and dict + values = obj.values() + elif isinstance(obj, (list, tuple)): + values = obj + + if values: + for v in values: + res = find_adam_mu(v) + if res is not None: + return res + return None + + sharded_fp32_params = find_adam_mu(opt_state) + if sharded_fp32_params is None: + actual_type = type(state_mesh_shardings.optimizer.get("opt_state", "None")) + raise NotImplementedError(f"Could not find Adam optimizer state in: {actual_type}") + + # Update model parameter sharding to match the mu (first moment) sharding. + # This ensures parameter sharding is consistent with the Zero-1 distributed layout. + # Build a path → new_PS lookup from sharded_fp32_params (mu), then update model_shardings + # at those paths while preserving rngs and any other non-Param variables. + mu_leaves_with_paths = list( + jax.tree_util.tree_leaves_with_path(sharded_fp32_params, is_leaf=lambda x: isinstance(x, nnx.Variable)) + ) + mu_lookup = {path: mu_var.get_value() for path, mu_var in mu_leaves_with_paths} + + def _update_model_var(path, var): + if path in mu_lookup: + return var.replace(mu_lookup[path]) + return var + + new_model_shardings = jax.tree_util.tree_map_with_path( + _update_model_var, model_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable) + ) + # Use jax.tree_util.tree_map (identity) to create a new nnx.State via JAX's unflatten + # mechanism (not the nnx.State constructor). This is critical because: + # 1. nnx.State({...}) constructor recursively converts nested plain dicts to nnx.State, + # causing a pytree type mismatch with the actual state from nnx.split (which stores + # nested module states as plain dicts). 
JAX's unflatten preserves the original types. + # 2. copy.deepcopy fails because NamedSharding contains non-picklable jaxlib.Device objects. + # Direct __setattr__ assignment stores new_model_shardings as-is (no type conversion). + updated_state = jax.tree_util.tree_map(lambda x: x, state_mesh_shardings, is_leaf=lambda x: isinstance(x, nnx.Variable)) + updated_state.model = new_model_shardings + + return prev_params_shardings, updated_state + + def logical_axis_rules_pp_act_as_dp(logical_rules): """Add stage as a physical axes before data for each rule, so stage acts just like data instead of PP. This is used when we want to pipeline only a subset of layers, and leave the rest like DP. diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index 906a597728..ca90550630 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -15,12 +15,14 @@ # pylint: disable=bare-except, consider-using-generator """Utils that are only interesting for training in MaxText.""" +import functools import os from functools import partial import jax -import functools +from flax import nnx from flax.linen import partitioning as nn_partitioning +from maxtext.layers import train_state_nnx from maxtext.common import checkpointing from maxtext.common.data_loader import create_dataloader from maxtext.common.goodput import GoodputEvent, maybe_record_goodput @@ -205,7 +207,7 @@ def setup_train_loop(config, recorder, devices=None): data_iterator: data_loader: rampup_manager: the class managing rampup batch sizes - state: the initialized train state + train_state: the initialized train state. 
For NNX, this is a TrainStateNNX instance """ # pylint: disable=import-outside-toplevel from maxtext.input_pipeline.input_pipeline_interface import create_data_iterator @@ -213,16 +215,22 @@ def setup_train_loop(config, recorder, devices=None): with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): is_training = True init_rng = jax.random.PRNGKey(config.init_weights_seed) + mesh = maxtext_utils.get_mesh_from_config(config, devices) if config.pure_nnx: # Create abstract NNX model. - raise NotImplementedError("Pure NNX support has not been implemented yet.") + _create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, mesh, devices) else: model = model_creation_utils.from_config(config, devices) - mesh = model.mesh learning_rate_schedule, tx = create_training_optimizer(config, model) + if config.pure_nnx: - # NNX has a different function to init the training state. - raise NotImplementedError("Pure NNX support has not been implemented yet.") + # For NNX, the train state is wrapped in the TrainStateNNX module. + def create_train_state_fn(): + model = _create_model_partial() + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(model, optimizer) + + init_state_fn = create_train_state_fn else: init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, is_training, init_rng) checkpoint_manager = create_checkpoint_manager(config, mesh, init_state_fn) @@ -266,6 +274,15 @@ def setup_train_loop(config, recorder, devices=None): state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( data_iterator, config, mesh, checkpoint_manager, init_state_fn ) + if config.pure_nnx: + with nn_partitioning.axis_rules(config.logical_axis_rules): + # train_state is instance of TrainStateNNX + state_graphdef, _ = nnx.get_abstract_model(init_state_fn, mesh) + _, state_params, _ = nnx.split(state.model, nnx.Param, ...) 
+ _, state_mesh_shardings_params, _ = nnx.split(state_mesh_shardings.model, nnx.Param, ...) + else: + state_params = state.params + state_mesh_shardings_params = state_mesh_shardings.params if config.enable_diloco: with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): @@ -283,17 +300,24 @@ def setup_train_loop(config, recorder, devices=None): # TODO(aireenmei, hengtaoguo): support sharding in vit for multimodal if not config.using_pipeline_parallelism and not config.use_multimodal: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage - sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance) + sharding.assert_params_sufficiently_sharded(state_params, mesh, config.sharding_tolerance) # print weights sharding info under debug sharding mode if config.debug_sharding: - logical_annotations = maxtext_utils.get_logical_annotations(config, mesh, init_state_fn) + if config.pure_nnx: + # TODO: Study how to get logical annotations of NNX module. Because of eager sharding, we + # probably already lost the logical partition info at this moment. 
+ logical_annotations_params = None + else: + logical_annotations = maxtext_utils.get_logical_annotations(config, mesh, init_state_fn) + logical_annotations_params = logical_annotations.params + max_utils.print_non_trivial_mesh_axis(model.mesh) - maxtext_utils.print_shardings_params( - state.params, state_mesh_shardings.params, model.mesh, logical_annotations.params - ) + maxtext_utils.print_shardings_params(state_params, state_mesh_shardings_params, mesh, logical_annotations_params) if config.use_dpo: + if config.pure_nnx: + raise NotImplementedError("DPO is not supported yet by NNX models.") abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training) max_logging.log( "Restoring reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" @@ -318,12 +342,18 @@ def setup_train_loop(config, recorder, devices=None): except FileNotFoundError: step0_restored = None if step0_restored is not None: + # TODO: For pure_nnx, the dpo state manipulation is different. reference_params = step0_restored["items"].params["params"] state = _merge_dpo_state(state, reference_params) else: max_logging.log( "Could not restore reference parameters for DPO from" f" '{os.path.join(str(config.checkpoint_dir), str(0))}'" ) + if config.pure_nnx: + train_state = nnx.merge(state_graphdef, state) + model = train_state.model + else: + train_state = state return ( init_rng, @@ -336,7 +366,7 @@ def setup_train_loop(config, recorder, devices=None): data_loader, rampup_manager, eval_data_iterator, - state, + train_state, ) diff --git a/tests/integration/setup_train_loop_nnx_test.py b/tests/integration/setup_train_loop_nnx_test.py new file mode 100644 index 0000000000..c15c59fd3b --- /dev/null +++ b/tests/integration/setup_train_loop_nnx_test.py @@ -0,0 +1,140 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration test for setup_train_loop with pure_nnx=True. + +setup_train_loop wires together create_nnx_abstract_model, the training optimizer, +the checkpoint manager, the data iterator, and finally nnx.split / nnx.merge to +return a fully-formed TrainStateNNX. This test exercises that wiring end-to-end +on a tiny synthetic config — the goal is to cover the integration glue that the +unit tests in tests/unit/train_utils_nnx_test.py cannot reach. +""" + +import os +import sys +import unittest + +import pytest + +import jax +from flax import nnx + +from maxtext.configs import pyconfig +from maxtext.layers import train_state_nnx +from maxtext.utils.globals import MAXTEXT_ASSETS_ROOT +from maxtext.utils.train_utils import setup_train_loop +from tests.utils.test_helpers import get_test_config_path + + +def _tiny_nnx_pyconfig(**overrides): + """Build a tiny pyconfig suitable for a single-host setup_train_loop run.""" + init_kwargs = { + "run_name": "setup_train_loop_nnx_test", + "enable_checkpointing": False, + "dataset_type": "synthetic", + "model_name": "default", + "pure_nnx": True, + "per_device_batch_size": 1.0, + "base_emb_dim": 8, + "base_num_query_heads": 4, + "base_num_kv_heads": 4, + "base_mlp_dim": 32, + "base_num_decoder_layers": 2, + "head_dim": 128, + "max_target_length": 128, + "vocab_size": 256, + "steps": 1, + "tokenizer_path": os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizers", "tokenizer.llama2"), + "enable_goodput_recording": False, + "enable_checkpoint_cloud_logger": False, + "monitor_goodput": False, + } + init_kwargs.update(overrides) + 
return pyconfig.initialize([sys.argv[0], get_test_config_path()], **init_kwargs) + + +@pytest.mark.integration_test +@pytest.mark.tpu_only +class SetupTrainLoopNNXIntegrationTest(unittest.TestCase): + """End-to-end check that setup_train_loop returns a usable TrainStateNNX.""" + + def test_pure_nnx_setup_returns_train_state_nnx(self): + config = _tiny_nnx_pyconfig() + + ( + init_rng, + checkpoint_manager, + state_mesh_shardings, + model, + mesh, + learning_rate_schedule, + data_iterator, + data_loader, + rampup_manager, + eval_data_iterator, + train_state, + ) = setup_train_loop(config, recorder=None) + + # The NNX path returns a fully-merged TrainStateNNX (lines 352-354 in train_utils.py). + self.assertIsInstance(train_state, train_state_nnx.TrainStateNNX) + # Optimizer.step starts at 0 for a fresh init. + self.assertEqual(int(train_state.optimizer.step.get_value()), 0) + # The returned model is train_state.model, an NNX module. + self.assertIsInstance(model, nnx.Module) + self.assertIs(model, train_state.model) + + # Sanity for sibling outputs: + self.assertIsNotNone(init_rng) + self.assertIsNotNone(mesh) + self.assertTrue(callable(learning_rate_schedule)) + # data_loader is mandatory; data_iterator may be wrapped/unwrapped. + self.assertIsNotNone(data_loader) + self.assertIsNotNone(data_iterator) + + # state_mesh_shardings (NNX) is an nnx.State and contains a 'model' branch. + self.assertIsInstance(state_mesh_shardings, nnx.State) + self.assertIn("model", state_mesh_shardings) + + # Cleanup: the rest are not asserted on but referenced so linters don't + # flag them as unused — they're part of the public return contract. + del checkpoint_manager, rampup_manager, eval_data_iterator + + def test_pure_nnx_setup_param_only_split_matches_model(self): + """nnx.split(state.model, nnx.Param, ...) 
must yield a non-empty Param tree + whose structure matches state_mesh_shardings.model after the same split.""" + config = _tiny_nnx_pyconfig() + *_, state_mesh_shardings, model, _, _, _, _, _, _, train_state = setup_train_loop(config, recorder=None) + + _, params, _ = nnx.split(train_state.model, nnx.Param, ...) + _, params_shardings, _ = nnx.split(state_mesh_shardings.model, nnx.Param, ...) + + # Same key-set after nnx.split — this is what setup_train_loop relies on at + # train_utils.py:281-282 to pair state_params with state_mesh_shardings_params. + self.assertEqual(jax.tree_util.tree_structure(params), jax.tree_util.tree_structure(params_shardings)) + self.assertGreater(len(jax.tree.leaves(params)), 0) + + del model + + def test_pure_nnx_dpo_raises_not_implemented(self): + """The use_dpo branch (train_utils.py:319-320) must raise for NNX.""" + # use_dpo requires a few prerequisites; the simplest is to set the flag and + # let setup_train_loop reach the NotImplementedError check before the more + # involved DPO path runs. + config = _tiny_nnx_pyconfig(use_dpo=True) + with self.assertRaises(NotImplementedError): + setup_train_loop(config, recorder=None) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/checkpointing_nnx_load_test.py b/tests/unit/checkpointing_nnx_load_test.py new file mode 100644 index 0000000000..622f19323a --- /dev/null +++ b/tests/unit/checkpointing_nnx_load_test.py @@ -0,0 +1,106 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX branches of load_state_if_possible.""" + +import unittest +from unittest import mock + +import jax +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.common import checkpointing +from maxtext.layers import train_state_nnx + + +class _Model(nnx.Module): + """Tiny single-linear NNX model for restore tests.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + +def _abstract_nnx_state(): + """Build an nnx.State from a TrainStateNNX — same shape that pre_train passes in.""" + model = _Model(rngs=nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + return nnx.state(train_state_nnx.TrainStateNNX(model, optimizer)) + + +class TestLoadStateIfPossibleNNX(unittest.TestCase): + """Cover the NNX branches in load_state_if_possible.""" + + def test_load_parameters_from_path_splits_nnx_state_for_param_view(self): + """When abstract_unboxed_pre_state is an nnx.State, the function must call + nnx.split(model, nnx.Param, ...) to get the params and forward them to load_params_from_path.""" + abstract = _abstract_nnx_state() + sentinel_restored = {"linear": {"kernel": jnp.ones((2, 1)), "bias": jnp.zeros((1,))}} + + with mock.patch.object(checkpointing, "load_params_from_path", return_value=sentinel_restored) as m: + full, params = checkpointing.load_state_if_possible( + checkpoint_manager=None, + data_iterator=None, + load_parameters_from_path="gs://does-not-exist/params", + load_full_state_from_path="", + checkpoint_storage_concurrent_gb=8, + abstract_unboxed_pre_state=abstract, + ) + + self.assertIsNone(full) + self.assertIs(params, sentinel_restored) + m.assert_called_once() + forwarded_params = m.call_args[0][1] # second positional arg = abstract_unboxed_params + # The forwarded params come from nnx.split(..., nnx.Param, ...) — same key shape as the model. 
+ leaves = jax.tree.leaves(forwarded_params) + self.assertEqual(len(leaves), 2) # linear.kernel + linear.bias + + def test_load_parameters_from_path_uses_state_params_for_linen(self): + """For Linen TrainState, the function must use state.params (not nnx.split).""" + fake_state = mock.Mock(spec=["params"]) + fake_state.params = {"layer": {"kernel": jnp.ones((2, 2))}} + sentinel = object() + + with mock.patch.object(checkpointing, "load_params_from_path", return_value=sentinel) as m: + full, params = checkpointing.load_state_if_possible( + checkpoint_manager=None, + data_iterator=None, + load_parameters_from_path="gs://does-not-exist/params", + load_full_state_from_path="", + checkpoint_storage_concurrent_gb=8, + abstract_unboxed_pre_state=fake_state, + ) + + self.assertIsNone(full) + self.assertIs(params, sentinel) + forwarded_params = m.call_args[0][1] + self.assertIs(forwarded_params, fake_state.params) + + def test_no_paths_returns_none_none(self): + """Sanity: with no checkpoint manager and no load paths, the function returns (None, None).""" + full, params = checkpointing.load_state_if_possible( + checkpoint_manager=None, + data_iterator=None, + load_parameters_from_path="", + load_full_state_from_path="", + checkpoint_storage_concurrent_gb=8, + abstract_unboxed_pre_state=_abstract_nnx_state(), + ) + self.assertIsNone(full) + self.assertIsNone(params) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/gradient_accumulation_nnx_test.py b/tests/unit/gradient_accumulation_nnx_test.py new file mode 100644 index 0000000000..6353f02397 --- /dev/null +++ b/tests/unit/gradient_accumulation_nnx_test.py @@ -0,0 +1,159 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX branch of gradient_accumulation_loss_and_grad.""" + +import unittest +from dataclasses import dataclass + +import jax +import jax.numpy as jnp +import numpy as np +from flax import nnx +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from maxtext.common.common_types import ShardMode +from maxtext.utils import gradient_accumulation + + +@dataclass +class _Cfg: + gradient_accumulation_steps: int = 2 + shard_optimizer_over_data: bool = False + shard_mode: int = ShardMode.AUTO + ici_data_parallelism: int = 1 + debug_sharding: bool = False + + +class _TinyNNX(nnx.Module): + """Single linear layer NNX model.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + def __call__(self, x): + return self.linear(x) + + +def _fake_loss_fn(model, config, data, dropout_rng, params, is_train=True): + """A loss_fn shaped like the production loss_fn but for a tiny linear model. + + Returns (loss, aux) where aux follows the schema gradient_accumulation_loss_and_grad + reads from: xent_sum / total_weights / moe_lb_loss / indexer_loss / mtp_loss. 
+ """ + del config, dropout_rng, params, is_train + pred = model(data["inputs"]) + per_sample_loss = jnp.mean((pred - data["targets"]) ** 2, axis=-1) + xent_sum = jnp.sum(per_sample_loss) + total_weights = jnp.array(per_sample_loss.shape[0], dtype=jnp.float32) + aux = { + "xent_sum": xent_sum, + "total_weights": total_weights, + "moe_lb_loss": jnp.array(0.0), + "indexer_loss": jnp.array(0.0), + "mtp_loss": jnp.array(0.0), + } + return xent_sum / total_weights, aux + + +class TestGradientAccumulationNNX(unittest.TestCase): + """Cover the NNX path of gradient_accumulation_loss_and_grad.""" + + def setUp(self): + self.model = _TinyNNX(rngs=nnx.Rngs(0)) + self.cfg = _Cfg(gradient_accumulation_steps=2) + # 4 examples → 2 microbatches of 2 each + self.data = { + "inputs": jnp.arange(8.0).reshape(4, 2), + "targets": jnp.zeros((4, 1)), + } + + def _params_shardings(self): + """Build a per-leaf NamedSharding tree shaped like nnx.split(model, nnx.Param, ...)[1]. + + Uses a trivial single-device mesh so jax.lax.with_sharding_constraint accepts the + sharding without contradicting the actual device topology. + """ + _, params, _ = nnx.split(self.model, nnx.Param, ...) 
+ mesh = Mesh( + np.array(jax.local_devices()[:1]).reshape( + 1, + ), + ("x",), + ) + ns = NamedSharding(mesh, PartitionSpec()) + return jax.tree.map(lambda _: ns, params) + + def test_nnx_path_runs_and_returns_grad_for_every_param(self): + """The NNX branch must call nnx.value_and_grad and return one gradient per Param.""" + loss, aux, raw_grads = gradient_accumulation.gradient_accumulation_loss_and_grad( + _fake_loss_fn, + self.cfg, + self.model, + params=None, # NNX branch ignores params + params_shardings=self._params_shardings(), + data=self.data, + dropout_rng=None, + extra_dpo_args=[], + ) + self.assertTrue(jnp.isfinite(loss)) + self.assertIn("xent_sum", aux) + self.assertIn("total_weights", aux) + grad_leaves = jax.tree.leaves(raw_grads) + self.assertEqual(len(grad_leaves), 2) # linear.kernel + linear.bias + for g in grad_leaves: + self.assertTrue(jnp.all(jnp.isfinite(g))) + + def test_nnx_path_updates_model_rest_state_after_scan(self): + """After accumulation, nnx.update is called on the model with the rest_state from the scan. + + For a TinyNNX (no rngs/dropout), the rest tree is empty but the call path must still + succeed end-to-end without raising — covering the `if is_nnx: nnx.update(...)` branch. + """ + pre_kernel = self.model.linear.kernel.value.copy() + gradient_accumulation.gradient_accumulation_loss_and_grad( + _fake_loss_fn, + self.cfg, + self.model, + params=None, + params_shardings=self._params_shardings(), + data=self.data, + dropout_rng=None, + extra_dpo_args=[], + ) + # The kernel itself is a Param — gradient_accumulation_loss_and_grad does not apply + # gradients to params, so the value should be untouched. 
+ self.assertTrue(jnp.allclose(self.model.linear.kernel.value, pre_kernel)) + + def test_nnx_with_shard_optimizer_over_data_casts_to_bf16(self): + """Zero-1 path must convert fp32 params to bf16 before the scan loop.""" + self.cfg.shard_optimizer_over_data = True + # Should not raise; just verify the function runs and returns sensible outputs. + loss, _, raw_grads = gradient_accumulation.gradient_accumulation_loss_and_grad( + _fake_loss_fn, + self.cfg, + self.model, + params=None, + params_shardings=self._params_shardings(), + data=self.data, + dropout_rng=None, + extra_dpo_args=[], + ) + self.assertTrue(jnp.isfinite(loss)) + for g in jax.tree.leaves(raw_grads): + self.assertTrue(jnp.all(jnp.isfinite(g))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/maxtext_utils_test.py b/tests/unit/maxtext_utils_test.py index 7a09750a86..2ef825c9a7 100644 --- a/tests/unit/maxtext_utils_test.py +++ b/tests/unit/maxtext_utils_test.py @@ -15,11 +15,13 @@ """Tests for the common MaxText utilities""" import functools -from typing import Any, Sequence from collections.abc import Callable +from typing import Any, Sequence import unittest from unittest.mock import MagicMock, Mock, patch from dataclasses import dataclass, field +import numpy as np +import optax from flax import linen as nn from flax import nnx @@ -29,6 +31,7 @@ from jax import random, vmap import jax.numpy as jnp from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec +from jax.experimental import mesh_utils from maxtext.configs import pyconfig from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN, ShardMode from maxtext.inference import inference_utils @@ -39,8 +42,7 @@ from maxtext.utils import sharding from maxtext.utils.sharding import assert_params_sufficiently_sharded, get_formatted_sharding_annotations from tests.utils.test_helpers import get_test_config_path, get_decoupled_parallelism_overrides -import numpy as np -import optax +from maxtext.utils 
import maxtext_utils_nnx Transformer = models.transformer_as_linen @@ -179,11 +181,7 @@ def setUp(self): "decoder": {"gate": {"bias": jnp.array([0.5, 0.5])}}, } self.state = train_state.TrainState( - step=0, - apply_fn=self.model.apply, - params=self.initial_params, - tx=None, - opt_state={}, + step=0, apply_fn=self.model.apply, params=self.initial_params, tx=None, opt_state={} ) def test_update_mode_add(self): @@ -196,10 +194,10 @@ def test_update_mode_add(self): self.assertTrue(jnp.allclose(actual, expected)) # Other values are untouched - original_layer_0 = self.state.params["layers"]["layer_0"]["bias"] + original_layer_0 = self.state.params["layers"]["layer_0"]["bias"] # pylint: disable=unsubscriptable-object new_layer_0 = new_state.params["layers"]["layer_0"]["bias"] self.assertTrue(jnp.array_equal(original_layer_0, new_layer_0)) - original_layer_1 = self.state.params["layers"]["layer_1"]["bias"] + original_layer_1 = self.state.params["layers"]["layer_1"]["bias"] # pylint: disable=unsubscriptable-object new_layer_1 = new_state.params["layers"]["layer_1"]["bias"] self.assertTrue(jnp.array_equal(original_layer_1, new_layer_1)) @@ -264,7 +262,7 @@ def test_init_training_state(self): @nnx.register_variable_name("special_variables") -class SpecialVariables(nnx.Variable): +class SpecialVariables(nnx.Variable): # pylint: disable=abstract-method pass @@ -281,7 +279,7 @@ def __call__(self, x, y, encoder_images=None, nnx_method=None, model_mode=None): return x -class TrainState(train_state.TrainState): +class TrainState(train_state.TrainState): # pylint: disable=abstract-method other_variables: nnx.State @@ -993,49 +991,63 @@ def train_step(_model, _config, _state_shardings, _params_shardings, state, _bat return train_step + def _make_mock_config(self, pure_nnx=False): + cfg = MagicMock() + cfg.pure_nnx = pure_nnx + return cfg + def test_returns_five_tuple(self): step = self._make_mock_step() result = maxtext_utils.get_functional_train_with_signature( - step, 
"data_sharding", "state_shardings", "model", "config" + step, "data_sharding", "state_shardings", "model", self._make_mock_config() ) self.assertEqual(len(result), 5) def test_functional_train_has_correct_name(self): step = self._make_mock_step() fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" + step, "data_sharding", "state_shardings", "model", self._make_mock_config() ) self.assertEqual(fn.__name__, "train_step") - def test_in_shardings_structure(self): + def test_linen_in_shardings_includes_rng(self): + """pure_nnx=False: in_shardings should be (state, batch, rng).""" step = self._make_mock_step() _, in_shardings, _, _, _ = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" + step, "data_sharding", "state_shardings", "model", self._make_mock_config(pure_nnx=False) ) - # (state, batch, rng) self.assertEqual(len(in_shardings), 3) self.assertIsNone(in_shardings[2]) # rng sharding is None + def test_nnx_in_shardings_excludes_rng(self): + """pure_nnx=True: in_shardings should be (state, batch) — no rng slot.""" + step = self._make_mock_step() + _, in_shardings, _, _, _ = maxtext_utils.get_functional_train_with_signature( + step, "data_sharding", "state_shardings", "model", self._make_mock_config(pure_nnx=True) + ) + self.assertEqual(len(in_shardings), 2) + def test_donate_argnums_is_zero(self): step = self._make_mock_step() _, _, _, _, donate_argnums = maxtext_utils.get_functional_train_with_signature( - step, "data_sharding", "state_shardings", "model", "config" + step, "data_sharding", "state_shardings", "model", self._make_mock_config() ) self.assertEqual(donate_argnums, 0) def test_functional_train_is_partial(self): """functional_train should partially apply model and config.""" received = {} + cfg = self._make_mock_config() def train_step(model, config, _state_shardings, _params_shardings, state, _batch, _rng=None): 
received["model"] = model received["config"] = config return state, {} - fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature(train_step, "ds", "ss", "my_model", "my_config") + fn, _, _, _, _ = maxtext_utils.get_functional_train_with_signature(train_step, "ds", "ss", "my_model", cfg) fn("state", "batch") self.assertEqual(received["model"], "my_model") - self.assertEqual(received["config"], "my_config") + self.assertEqual(received["config"], cfg) class TestGetFunctionalEvalWithSignature(unittest.TestCase): @@ -1047,26 +1059,51 @@ def eval_step(_model, _config, _state, _batch, _rng=None): return eval_step + def _make_mock_config(self, pure_nnx=False): + cfg = MagicMock() + cfg.pure_nnx = pure_nnx + return cfg + def test_returns_five_tuple(self): step = self._make_mock_eval_step() - result = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + result = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", self._make_mock_config()) self.assertEqual(len(result), 5) def test_functional_eval_has_correct_name(self): step = self._make_mock_eval_step() - fn, _, _, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + fn, _, _, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", self._make_mock_config()) self.assertEqual(fn.__name__, "eval_step") def test_out_shardings_is_none(self): step = self._make_mock_eval_step() - _, _, out_shardings, _, _ = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + _, _, out_shardings, _, _ = maxtext_utils.get_functional_eval_with_signature( + step, "ds", "ss", "model", self._make_mock_config() + ) self.assertIsNone(out_shardings) def test_donate_argnums_is_empty(self): step = self._make_mock_eval_step() - _, _, _, _, donate_argnums = maxtext_utils.get_functional_eval_with_signature(step, "ds", "ss", "model", "config") + _, _, _, _, donate_argnums = 
maxtext_utils.get_functional_eval_with_signature( + step, "ds", "ss", "model", self._make_mock_config() + ) self.assertEqual(donate_argnums, ()) + def test_nnx_in_shardings_excludes_rng(self): + """pure_nnx=True: in_shardings should be (state, batch) — no rng slot.""" + step = self._make_mock_eval_step() + _, in_shardings, _, _, _ = maxtext_utils.get_functional_eval_with_signature( + step, "batch_sharding", "state_sharding", "model", self._make_mock_config(pure_nnx=True) + ) + self.assertEqual(len(in_shardings), 2) + + def test_linen_in_shardings_includes_rng(self): + """pure_nnx=False: in_shardings should be (state, batch, rng).""" + step = self._make_mock_eval_step() + _, in_shardings, _, _, _ = maxtext_utils.get_functional_eval_with_signature( + step, "batch_sharding", "state_sharding", "model", self._make_mock_config(pure_nnx=False) + ) + self.assertEqual(len(in_shardings), 3) + class TestGetShapedBatch(unittest.TestCase): """Tests for get_shaped_batch.""" @@ -1414,5 +1451,183 @@ def test_runs_without_logical_annotations(self): maxtext_utils.print_shardings_params(params, param_sharding, mesh=self.mesh, logical_annotations=None) +class TestNNXAbstractState(unittest.TestCase): + """Test the get_abstract_state_nnx func.""" + + @dataclass + class MockConfig: + init_weights_seed: int = 42 + shard_optimizer_over_data: bool = False + optimizer_memory_host_offload: bool = False + parameter_memory_host_offload: bool = False + param_scan_axis: int = 0 + logical_axis_rules: list = field(default_factory=lambda: [["data", ["data"]]]) + + class MockTrainState(nnx.Module): + """Simulates a TrainState with params and optimizer state.""" + + def __init__(self, rngs: nnx.Rngs): + # Model parameters + device_num = len(jax.local_devices()) + self.params = nnx.Linear( + 2, 4, kernel_init=nnx.with_partitioning(nnx.initializers.ones, sharding=("model",)), rngs=rngs + ) + # Simulated optimizer state + self.optimizer = nnx.Variable(jnp.zeros((device_num,)), sharding=("model",)) + + 
def setUp(self): + # Create a real 1D mesh on local devices + devices = jax.local_devices() + self.mesh = Mesh(mesh_utils.create_device_mesh((len(devices), 1)), axis_names=("model", "data")) + self.config = self.MockConfig() + + def nnx_init_trainstate_wrapper(self): + """Wrapper to initialize the mock NNX model.""" + rngs = maxtext_utils_nnx.create_nnx_rngs(self.config) + return self.MockTrainState(rngs) + + def test_basic_abstraction(self): + """Verifies the basic return structure and partition spec extraction.""" + abstract_state, annotations, shardings = maxtext_utils.get_abstract_state_nnx( + self.config, self.mesh, self.nnx_init_trainstate_wrapper + ) + + # Check return types + self.assertIsInstance(abstract_state, nnx.State) + self.assertIsInstance(annotations, nnx.State) + self.assertIsInstance(shardings, nnx.State) + + # Verify PartitionSpec was extracted correctly from the mock model's annotations + # Path: params -> kernel -> spec + self.assertEqual( + annotations.params.kernel.get_value(), + PartitionSpec( + "model", + ), + ) + + def test_shard_optimizer_over_data(self): + """Verifies that 'data' is added to optimizer sharding using the real utility.""" + self.config.shard_optimizer_over_data = True + + _, annotations, _ = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Original Pspec for optimizer was PartitionSpec(None). + # add_data_to_sharding should find that dim 0 is compatible with mesh 'data' + # and update it to PartitionSpec(('data',)). 
+ opt_spec = annotations.optimizer.get_value() + + # Verify 'data' is now in the spec + self.assertEqual(opt_spec, PartitionSpec(("data", "model"))) + + def test_optimizer_host_offload(self): + """Verifies that optimizer memory is moved to host when configured.""" + self.config.optimizer_memory_host_offload = True + + _, _, shardings = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Optimizer state should be pinned to host + opt_sharding = shardings.optimizer.get_value() + self.assertEqual(opt_sharding.memory_kind, "pinned_host") + + # Params should still be on default memory (usually device) + param_sharding = shardings.params.kernel.get_value() + self.assertNotEqual(param_sharding.memory_kind, "pinned_host") + + def test_parameter_host_offload(self): + """Verifies that parameter memory is moved to host when configured.""" + self.config.parameter_memory_host_offload = True + self.config.param_scan_axis = 0 + + _, _, shardings = maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, self.nnx_init_trainstate_wrapper) + + # Parameters should be pinned to host + param_sharding = shardings.params.kernel.get_value() + self.assertEqual(param_sharding.memory_kind, "pinned_host") + + def test_invalid_init_fn(self): + """Ensures function raises error if no init function is provided.""" + with self.assertRaises(AssertionError): + maxtext_utils.get_abstract_state_nnx(self.config, self.mesh, None) + + +class TestGetNnxNamedShardingWithScanAxis(unittest.TestCase): + """Unit tests for get_nnx_named_sharding_with_scan_axis covering every branch. + + The helper resolves a NamedSharding for each NNX Variable and — unlike + flax.nnx.spmd.get_var_pspec — also inserts the `nnx.PARTITION_NAME` axis at + `param_scan_axis` when scanned-layers metadata is present. + """ + + def setUp(self): + # Mesh needs to contain every axis name the tests reference in partition specs. 
+ self.mesh = Mesh(np.array(jax.local_devices()[:1]).reshape(1, 1), ("fsdp", "layers")) + + def _build_state(self, **variables): + """Wrap a dict of {key: nnx.Variable} in an nnx.State for tree traversal.""" + return nnx.State(variables) + + def _run(self, state): + return maxtext_utils.get_nnx_named_sharding_with_scan_axis(state, self.mesh) + + def test_scan_axis_inserted_at_param_scan_axis(self): + """When PARTITION_NAME is present, the partition name is inserted at `param_scan_axis`.""" + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((3, 4, 8)), + out_sharding=(None, "fsdp"), + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 1}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertIsInstance(result_sharding, NamedSharding) + # 'layers' must be inserted at position 1 (param_scan_axis=1). + self.assertEqual(result_sharding.spec, PartitionSpec(None, "layers", "fsdp")) + + def test_scan_axis_not_inserted_when_already_present(self): + """Guard against double-insertion when partition_name is already in out_sharding.""" + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((2, 2, 2)), + out_sharding=("layers", None, "fsdp"), + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 0}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # 'layers' must appear exactly once — the same PartitionSpec we started with. + self.assertEqual(result_sharding.spec, PartitionSpec("layers", None, "fsdp")) + + def test_masked_node_preserved_as_is(self): + """Values without a .shape attribute (e.g., optax.MaskedNode) are returned unchanged.""" + masked = nnx.Variable(optax.MaskedNode()) + state = self._build_state(masked=masked) + out = self._run(state) + # The leaf must be the original Variable, not a NamedSharding wrapper. 
+ self.assertIs(out["masked"], masked) + + def test_empty_out_sharding_yields_empty_pspec(self): + """A Variable without any sharding metadata should resolve to PartitionSpec().""" + with jax.set_mesh(self.mesh): + # No out_sharding/sharding_names/sharding metadata → falsy → PartitionSpec() + v = nnx.Param(jnp.zeros((4,))) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + self.assertIsInstance(result_sharding, NamedSharding) + self.assertEqual(result_sharding.spec, PartitionSpec()) + + def test_string_out_sharding_is_wrapped_into_tuple(self): + """A single-string out_sharding value should still produce a valid PartitionSpec.""" + with jax.set_mesh(self.mesh): + v = nnx.Param( + jnp.zeros((4,)), + out_sharding="fsdp", + **{nnx.PARTITION_NAME: "layers", "param_scan_axis": 0}, + ) + out = self._run(self._build_state(w=v)) + result_sharding = out["w"].get_value() + # The single string 'fsdp' is turned into a list, and 'layers' is prepended. + self.assertEqual(result_sharding.spec, PartitionSpec("layers", "fsdp")) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/muon_utils_test.py b/tests/unit/muon_utils_test.py new file mode 100644 index 0000000000..9570257eee --- /dev/null +++ b/tests/unit/muon_utils_test.py @@ -0,0 +1,224 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for muon_utils.py.""" + +# pylint: disable=protected-access + +import io +import contextlib +import unittest +from unittest import mock + +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax import nnx +from optax.contrib._muon import MuonDimensionNumbers as mdn + +from maxtext.utils import muon_utils + + +class TestIsPathContainAny(unittest.TestCase): + """Tests for _is_path_contain_any helper.""" + + def test_returns_true_when_any_element_in_path(self): + self.assertTrue(muon_utils._is_path_contain_any(("bias", "scale"), ("decoder", "bias"))) + + def test_returns_false_when_no_element_in_path(self): + self.assertFalse(muon_utils._is_path_contain_any(("bias", "scale"), ("decoder", "kernel"))) + + def test_empty_tuples_returns_false(self): + self.assertFalse(muon_utils._is_path_contain_any((), ("decoder", "kernel"))) + + +class TestTransformLogic(unittest.TestCase): + """Tests for transform_logic: covers every branch of the mapping.""" + + # --- 1. 
Exclusions --- + def test_scale_is_excluded(self): + self.assertIsNone(muon_utils.transform_logic(("decoder", "norm", "scale"))) + + def test_bias_is_excluded(self): + self.assertIsNone(muon_utils.transform_logic(("decoder", "dense", "bias"))) + + def test_embedding_is_excluded(self): + self.assertIsNone(muon_utils.transform_logic(("token_embedder", "embedding"))) + + def test_logits_dense_is_excluded(self): + self.assertIsNone(muon_utils.transform_logic(("decoder", "logits_dense", "kernel"))) + + # --- 2.1 MoE --- + def test_moe_wi_0_uses_last_two_axes(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "MoeBlock_0", "wi_0")), mdn((-2,), (-1,))) + + def test_moe_wi_1_uses_last_two_axes(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "MoeBlock_0", "wi_1")), mdn((-2,), (-1,))) + + def test_moe_wo_uses_last_two_axes(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "MoeBlock_0", "wo")), mdn((-2,), (-1,))) + + def test_moe_gate_falls_through_to_standard(self): + # 'gate' is inside MoeBlock_0 but not one of (wi_0, wi_1, wo) → standard. 
+ self.assertEqual(muon_utils.transform_logic(("decoder", "MoeBlock_0", "gate", "kernel")), mdn((0,), (-1,))) + + # --- 2.2 Self-attention --- + def test_self_attention_out_projection(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "out")), mdn((0, -2), (-1,))) + + def test_self_attention_query_projection(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "query")), mdn((0,), (-2, -1))) + + def test_self_attention_key_projection(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "key")), mdn((0,), (-2, -1))) + + def test_self_attention_value_projection(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "value")), mdn((0,), (-2, -1))) + + def test_self_attention_wq_b_and_wkv_b(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "wq_b")), mdn((0,), (-2, -1))) + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "wkv_b")), mdn((0,), (-2, -1))) + + def test_self_attention_mla_wq_a_is_excluded_from_special(self): + # wq_a / wkv_a are MLA down-projections; they fall through the self_attention branch + # without matching anything, so the function returns the default standard mdn((0,), (-1,)). + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "wq_a")), mdn((0,), (-1,))) + self.assertEqual(muon_utils.transform_logic(("decoder", "self_attention", "wkv_a")), mdn((0,), (-1,))) + + # --- 3. 
Standard --- + def test_standard_weight(self): + self.assertEqual(muon_utils.transform_logic(("decoder", "mlp", "kernel")), mdn((0,), (-1,))) + + +class TestGetTransformTree(unittest.TestCase): + """Tests for get_transform_tree: recursive dict walk that applies transform_logic.""" + + def test_nested_dict_is_walked(self): + tree = {"decoder": {"self_attention": {"out": 0}, "mlp": {"kernel": 0}}} + result = muon_utils.get_transform_tree(tree) + self.assertEqual(result["decoder"]["self_attention"]["out"], mdn((0, -2), (-1,))) + self.assertEqual(result["decoder"]["mlp"]["kernel"], mdn((0,), (-1,))) + + def test_excluded_leaves_become_none(self): + tree = {"decoder": {"norm": {"scale": 0}}} + self.assertIsNone(muon_utils.get_transform_tree(tree)["decoder"]["norm"]["scale"]) + + def test_non_dict_leaf_at_root_returns_transform(self): + # If the tree itself is a leaf, path=() and transform_logic returns the standard mdn. + self.assertEqual(muon_utils.get_transform_tree(0), mdn((0,), (-1,))) + + +class _MoeLikeNNXModel(nnx.Module): + """Small NNX model whose param paths exercise the NNX branch of get_muon_weight_dimension_numbers.""" + + def __init__(self, rngs): + # Names are chosen so transform_logic matches each of the three meaningful branches: + # - w_standard: default mdn + # - self_attention_out: attention-out mdn + # - scale: excluded (None) + self.w_standard = nnx.Param(jnp.ones((4, 8))) + self.self_attention_out = nnx.Param(jnp.ones((4, 8))) + self.scale = nnx.Param(jnp.ones((8,))) + + +class TestGetMuonWeightDimensionNumbersNNX(unittest.TestCase): + """Covers the NNX branch of get_muon_weight_dimension_numbers (isinstance(model, nnx.Module)).""" + + def setUp(self): + self.model = _MoeLikeNNXModel(rngs=nnx.Rngs(0)) + + def test_nnx_model_dispatches_to_tree_map_with_path(self): + """NNX branch should produce an nnx.State tree with transform_logic applied per leaf.""" + result = muon_utils.get_muon_weight_dimension_numbers(self.model, config=None) + + # Result is 
an nnx.State whose top-level keys mirror the model attributes. + self.assertIn("w_standard", result) + self.assertIn("self_attention_out", result) + self.assertIn("scale", result) + + # NNX Variables are walked by jax.tree_util.tree_map_with_path, so the returned + # tree replaces each Variable's value with transform_logic(path_strings). + # 'scale' matches the exclusion branch → value is None. + self.assertIsNone(result["scale"].get_value()) + # 'w_standard' does not trigger any special rule → standard mdn. + self.assertEqual(result["w_standard"].get_value(), mdn((0,), (-1,))) + + def test_nnx_verbose_path_executes_print_debug(self): + """verbose=True should also execute _print_structure_debug without raising.""" + buf = io.StringIO() + with contextlib.redirect_stdout(buf): + muon_utils.get_muon_weight_dimension_numbers(self.model, config=None, verbose=True) + self.assertIn("Model Structure", buf.getvalue()) + self.assertIn("Muon Dimension Numbers", buf.getvalue()) + + +class TestGetMuonWeightDimensionNumbersLinen(unittest.TestCase): + """Covers the Linen branch of get_muon_weight_dimension_numbers.""" + + def test_linen_branch_uses_get_abstract_param(self): + """Linen models dispatch to maxtext_utils.get_abstract_param + get_transform_tree.""" + # Build a Linen nn.Module so isinstance(model, nnx.Module) is False. + + class LinenStub(nn.Module): + + @nn.compact + def __call__(self, x): + return x + + model = LinenStub() + + # Mock the heavy get_abstract_param call with a pre-shaped dict that exercises + # both a standard weight path and an excluded path. 
+ fake_abstract_param = { + "params": { + "self_attention": {"out": object()}, + "norm": {"scale": object()}, + }, + } + + with mock.patch.object(muon_utils.maxtext_utils, "get_abstract_param", return_value=fake_abstract_param): + result = muon_utils.get_muon_weight_dimension_numbers(model, config=mock.MagicMock()) + + self.assertEqual(result["params"]["self_attention"]["out"], mdn((0, -2), (-1,))) + self.assertIsNone(result["params"]["norm"]["scale"]) + + +class TestPrintStructureDebug(unittest.TestCase): + """Covers both branches of get_leaf_info inside _print_structure_debug.""" + + def test_handles_logically_partitioned_leaf(self): + """Linen leaves are nn.LogicallyPartitioned; the helper should return {shape, names}.""" + leaf = nn.LogicallyPartitioned(value=jax.ShapeDtypeStruct((4, 8), jnp.float32), names=("embed", "mlp")) + tree = {"params": {"kernel": leaf}} + + buf = io.StringIO() + with contextlib.redirect_stdout(buf): + muon_utils._print_structure_debug(tree, muon_weight_dimension_numbers={"params": {"kernel": mdn((0,), (-1,))}}) + out = buf.getvalue() + self.assertIn("(4, 8)", out) + self.assertIn("embed", out) + + def test_handles_shape_dtype_struct_leaf(self): + """NNX abstract leaves are ShapeDtypeStruct directly; the helper should return {shape}.""" + tree = {"kernel": jax.ShapeDtypeStruct((16, 32), jnp.float32)} + + buf = io.StringIO() + with contextlib.redirect_stdout(buf): + muon_utils._print_structure_debug(tree, muon_weight_dimension_numbers={"kernel": mdn((0,), (-1,))}) + out = buf.getvalue() + self.assertIn("(16, 32)", out) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/nnx_decoders_test.py b/tests/unit/nnx_decoders_test.py index 8979440732..00f761c2cf 100644 --- a/tests/unit/nnx_decoders_test.py +++ b/tests/unit/nnx_decoders_test.py @@ -31,7 +31,12 @@ from flax import nnx from jax.sharding import Mesh -from maxtext.common.common_types import DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_TRAIN, DecoderBlockType 
+from maxtext.common.common_types import ( + DECODING_ACTIVE_SEQUENCE_INDICATOR, + MODEL_MODE_TRAIN, + DecoderBlockType, + MultimodalInput, +) from maxtext.configs import pyconfig from maxtext.layers import linears from maxtext.layers.attentions import Attention @@ -507,6 +512,88 @@ def test_logits_are_finite(self): ) self.assertTrue(jnp.all(jnp.isfinite(logits))) + def test_multimodal_input_unpacks_into_individual_fields(self): + """Passing `multimodal_input=...` must forward each field into `_apply_embedding`. + + The decoder accepts either a `MultimodalInput` struct or the individual + image/audio/bidirectional_mask arguments. When both forms are provided, the + unpacked struct takes precedence. This test stubs `_apply_embedding` to + capture the forwarded positional arguments without running the real + embedding path (the test config has `use_multimodal=False`). + """ + ids, segment_ids, positions = self._make_token_inputs() + + # Distinct sentinels so each field can be traced independently. + sentinel_img_emb = jnp.full((1, 1), 11.0) + sentinel_img_mask = jnp.full((1, 1), 22.0) + sentinel_aud_emb = jnp.full((1, 1), 33.0) + sentinel_aud_mask = jnp.full((1, 1), 44.0) + sentinel_bidir = jnp.full((1, 1), 55.0) + + mm_input = MultimodalInput( + image_embeddings=sentinel_img_emb, + image_masks=sentinel_img_mask, + audio_embeddings=sentinel_aud_emb, + audio_masks=sentinel_aud_mask, + bidirectional_mask=sentinel_bidir, + ) + + captured = {} + + def fake_apply_embedding( + _shared_embedding, + _ids, + _positions, + _deterministic, + _model_mode, + image_embeddings, + bidirectional_mask, + image_masks, + audio_embeddings, + audio_masks, + ): + captured.update( + image_embeddings=image_embeddings, + image_masks=image_masks, + audio_embeddings=audio_embeddings, + audio_masks=audio_masks, + bidirectional_mask=bidirectional_mask, + ) + # Return a correctly-shaped tensor so the rest of __call__ can proceed. 
+ batch = self.cfg.global_batch_size_to_train_on + seq_len = self.cfg.max_target_length + emb_dim = self.cfg.emb_dim + return jnp.zeros((batch, seq_len, emb_dim), dtype=self.cfg.dtype) + + self.decoder._apply_embedding = fake_apply_embedding # pylint: disable=protected-access + try: + self.decoder( + self.shared_embedding, + ids, + positions, + decoder_segment_ids=segment_ids, + deterministic=True, + model_mode=MODEL_MODE_TRAIN, + # Intentionally pass the individual args as None; multimodal_input must override them. + image_embeddings=None, + image_masks=None, + audio_embeddings=None, + audio_masks=None, + bidirectional_mask=None, + multimodal_input=mm_input, + ) + finally: + # NNX modules bind attributes statefully; remove the override to avoid leaking. + del self.decoder._apply_embedding # pylint: disable=protected-access + + # Every field in the MultimodalInput struct must have been forwarded + # unchanged into _apply_embedding's arguments (not the None overrides). + self.assertTrue(jnp.array_equal(captured["image_embeddings"], sentinel_img_emb)) + self.assertTrue(jnp.array_equal(captured["image_masks"], sentinel_img_mask)) + self.assertTrue(jnp.array_equal(captured["audio_embeddings"], sentinel_aud_emb)) + self.assertTrue(jnp.array_equal(captured["audio_masks"], sentinel_aud_mask)) + self.assertTrue(jnp.array_equal(captured["bidirectional_mask"], sentinel_bidir)) + def test_different_random_seeds_produce_different_logits(self): """Two randomly-initialised decoders should not produce identical logits.""" cfg = self.cfg diff --git a/tests/unit/optimizers_test.py b/tests/unit/optimizers_test.py index 44623f24f3..5194719ce2 100644 --- a/tests/unit/optimizers_test.py +++ b/tests/unit/optimizers_test.py @@ -15,19 +15,19 @@ """ Unit tests for all optimizers. 
""" import re import unittest -from unittest.mock import patch +from unittest.mock import patch, MagicMock import jax import optax import jax.numpy as jnp import pytest from absl.testing import parameterized +from flax import nnx from optax.contrib import MuonDimensionNumbers as mdn from maxtext.configs import pyconfig from maxtext.optimizers import optimizers -from maxtext.utils import maxtext_utils -from maxtext.utils.muon_utils import get_model_mdn +from maxtext.utils import maxtext_utils, muon_utils from tests.utils.test_helpers import get_test_config_path from typing import NamedTuple @@ -49,6 +49,7 @@ DEEPSEEK2_DIMENSION_NUMBER = { "params": { "decoder": { + "decoder_norm": {"scale": None}, "dense_layers": { "mlp": { "wi_0": {"kernel": mdn((0,), (-1,))}, @@ -57,6 +58,7 @@ }, **_DEEPSEEK2_ATTENTION, }, + "logits_dense": {"kernel": None}, "moe_layers": { "DeepSeekMoeBlock_0": { "MoeBlock_0": { @@ -73,8 +75,6 @@ }, **_DEEPSEEK2_ATTENTION, }, - "decoder_norm": {"scale": None}, - "logits_dense": {"kernel": None}, }, "token_embedder": {"embedding": None}, } @@ -99,6 +99,7 @@ DEEPSEEK3_DIMENSION_NUMBER = { "params": { "decoder": { + "decoder_norm": {"scale": None}, "dense_layers": { "mlp": { "wi_0": {"kernel": mdn((0,), (-1,))}, @@ -107,6 +108,7 @@ }, **_DEEPSEEK3_ATTENTION, }, + "logits_dense": {"kernel": None}, "moe_layers": { "DeepSeekMoeBlock_0": { "MoeBlock_0": { @@ -123,8 +125,6 @@ }, **_DEEPSEEK3_ATTENTION, }, - "decoder_norm": {"scale": None}, - "logits_dense": {"kernel": None}, }, "token_embedder": {"embedding": None}, } @@ -243,7 +243,7 @@ def test_model_integration(self, model_name, expected_output): Initializes the specified MaxText model and asserts that the generated Muon dimension numbers match the hardcoded reference. 
""" - actual_output = get_model_mdn(model_name, scan_layers=True) + actual_output = muon_utils.get_model_mdn(model_name, scan_layers=True, pure_nnx=False) self.assertEqual(actual_output, expected_output) @@ -483,5 +483,105 @@ def test_no_skip_without_kwargs(self): self.assertEqual(opt_state["count"], 0) +class TestMuonLogic(unittest.TestCase): + """Tests the granular path transformation functions.""" + + def test_is_path_contain_any(self): + # pylint: disable=protected-access + self.assertTrue(muon_utils._is_path_contain_any(("a", "b"), ("x", "a", "z"))) + self.assertFalse(muon_utils._is_path_contain_any(("a", "b"), ("x", "y", "z"))) + + def test_transform_logic_exclusions(self): + self.assertIsNone(muon_utils.transform_logic(("layer_0", "bias"))) + self.assertIsNone(muon_utils.transform_logic(("layer_0", "scale"))) + self.assertIsNone(muon_utils.transform_logic(("embedding", "kernel"))) + + def test_transform_logic_moe(self): + path = ("layers_0", "MoeBlock_0", "wi_0") + result = muon_utils.transform_logic(path) + self.assertEqual(result.reduction_axis, (-2,)) + self.assertEqual(result.output_axis, (-1,)) + + def test_transform_logic_attention(self): + path_out = ("layers_0", "self_attention", "out", "kernel") + self.assertEqual(muon_utils.transform_logic(path_out), mdn((0, -2), (-1,))) + + path_q = ("layers_0", "self_attention", "query", "kernel") + self.assertEqual(muon_utils.transform_logic(path_q), mdn((0,), (-2, -1))) + + def test_get_transform_tree(self): + fake_tree = {"params": {"layer_0": {"kernel": "leaf", "bias": "leaf"}, "MoeBlock_0": {"wi_0": "leaf"}}} + result = muon_utils.get_transform_tree(fake_tree) + self.assertEqual(result["params"]["layer_0"]["kernel"], mdn((0,), (-1,))) + self.assertIsNone(result["params"]["layer_0"]["bias"]) + + def test_get_muon_weight_dimension_numbers_nnx(self): + """Verifies dimension extraction for stateful NNX modules.""" + + class MockNNXModel(nnx.Module): + """Mock NNX Module.""" + + def __init__(self, rngs: 
nnx.Rngs): + # 1. Standard layer + self.layer1 = nnx.Linear(2, 4, rngs=rngs) + + # 2. MoE specific naming to trigger transform logic. + # The logic expects "MoeBlock_0" AND "wi_0"/"wi_1"/"wo" in the path. + # We nest the linear layer to create the path: ('MoeBlock_0', 'wi_0', 'kernel') + self.MoeBlock_0 = nnx.Module() + self.MoeBlock_0.wi_0 = nnx.Linear(4, 2, rngs=rngs) + + # 3. Exclusion case (scaler/scale) + self.scale = nnx.Param(jnp.ones((1,))) + + # Use eval_shape to create an abstract version of the model. + model = nnx.eval_shape(lambda: MockNNXModel(rngs=nnx.Rngs(0))) + config = MagicMock() + + # Extract dimension numbers using the NNX path in muon_utils + result = muon_utils.get_muon_weight_dimension_numbers(model, config) + + # Verify standard weight path: ('layer1', 'kernel') -> default (0,) + self.assertEqual(result.layer1.kernel.value, mdn((0,), (-1,))) + + # Verify MoE weight path: ('MoeBlock_0', 'wi_0', 'kernel') -> (-2,) + self.assertEqual(result.MoeBlock_0.wi_0.kernel.value, mdn((-2,), (-1,))) + + # Verify exclusion (scalar/scale) + self.assertIsNone(result.scale.value) + + def test_verbose_output_nnx(self): + """Covers lines 128 and 135-154: _print_structure_debug via verbose=True with NNX model.""" + + class SimpleNNXModel(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 4, rngs=rngs) + + model = nnx.eval_shape(lambda: SimpleNNXModel(rngs=nnx.Rngs(0))) + config = MagicMock() + muon_utils.get_muon_weight_dimension_numbers(model, config, verbose=True) + + def test_nnx_deepseek_attention_logic(self): + """Simulates a DeepSeek-like attention structure in NNX.""" + + class DeepSeekAttention(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.self_attention = nnx.Module() + self.self_attention.query = nnx.Linear(8, 8, rngs=rngs) + self.self_attention.out = nnx.Linear(8, 8, rngs=rngs) + + # Use eval_shape to create an abstract version of the model. 
+ model = nnx.eval_shape(lambda: DeepSeekAttention(nnx.Rngs(0))) + config = MagicMock() + result = muon_utils.get_muon_weight_dimension_numbers(model, config) + + # Check attention query: [0] -> [-2, -1] + self.assertEqual(result.self_attention.query.kernel.value, mdn((0,), (-2, -1))) + # Check attention out: [0, -2] -> [-1] + self.assertEqual(result.self_attention.out.kernel.value, mdn((0, -2), (-1,))) + + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/sharding_nnx_test.py b/tests/unit/sharding_nnx_test.py new file mode 100644 index 0000000000..3cda286c68 --- /dev/null +++ b/tests/unit/sharding_nnx_test.py @@ -0,0 +1,161 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX-specific helpers in maxtext.utils.sharding.""" + +import unittest +from dataclasses import dataclass + +import jax +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from flax import nnx +import numpy as np +import optax + +from maxtext.layers import train_state_nnx +from maxtext.utils import sharding + + +@dataclass +class _Cfg: + pure_nnx: bool = True + shard_optimizer_over_data: bool = False + + +class _LinearNNX(nnx.Module): + """Tiny NNX model with a single Linear layer for sharding tests.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 4, rngs=rngs) + + +def _build_state_mesh_shardings(model, tx): + """Build an nnx.State of NamedShardings mirroring the TrainStateNNX layout. 
+ + This emulates what get_abstract_state_nnx returns: an nnx.State whose leaves + are nnx.Variable wrappers around NamedSharding objects. + """ + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + state_obj = train_state_nnx.TrainStateNNX(model, optimizer) + state = nnx.state(state_obj) + mesh = Mesh(np.array(jax.local_devices()[:1]).reshape(1, 1), ("data", "model")) + + def _to_sharding(var): + val = var.get_value() + if not hasattr(val, "shape") or val.ndim == 0: + pspec = PartitionSpec() + elif val.ndim == 1: + pspec = PartitionSpec("model") + else: + pspec = PartitionSpec("data", "model") + return var.replace(NamedSharding(mesh, pspec)) + + return jax.tree.map(_to_sharding, state, is_leaf=lambda x: isinstance(x, nnx.Variable)) + + +class TestMaybeUpdateParamsShardingWithOptNNX(unittest.TestCase): + """Cover the NNX branches of maybe_update_params_sharding_with_opt.""" + + def setUp(self): + self.model = _LinearNNX(rngs=nnx.Rngs(0)) + + def test_dispatch_from_main_helper_when_pure_nnx(self): + """maybe_update_params_sharding_with_opt should dispatch to the NNX variant.""" + cfg = _Cfg(pure_nnx=True, shard_optimizer_over_data=False) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + prev, updated = sharding.maybe_update_params_sharding_with_opt(cfg, state_mesh_shardings) + # prev is the param-only view (no rngs / non-Param nodes) + self.assertIsInstance(prev, nnx.State) + self.assertIn("linear", prev) + # updated is unchanged because shard_optimizer_over_data=False + self.assertIs(updated, state_mesh_shardings) + + def test_extract_param_only_skips_non_param_variables(self): + """prev_params_shardings must contain Params only — RngKey/RngCount/OptVariable filtered out.""" + cfg = _Cfg(shard_optimizer_over_data=False) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + prev, _ = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + leaves = jax.tree.leaves(prev, 
is_leaf=lambda x: isinstance(x, nnx.Variable)) + # Every surviving leaf is wrapped as an nnx.Param. + self.assertTrue(all(isinstance(leaf, nnx.Param) for leaf in leaves)) + # The model has linear.kernel and linear.bias — exactly two Param leaves. + self.assertEqual(len(leaves), 2) + + def test_returns_unchanged_when_shard_optimizer_over_data_false(self): + """When shard_optimizer_over_data=False, the second return value must be the input object.""" + cfg = _Cfg(shard_optimizer_over_data=False) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + _, updated = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + self.assertIs(updated, state_mesh_shardings) + + def test_zero1_propagates_mu_sharding_to_model_params(self): + """Zero-1: model param shardings must be replaced with the optimizer mu shardings.""" + cfg = _Cfg(shard_optimizer_over_data=True) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.adam(1e-3)) + + # Mutate the optimizer mu leaves in place so the function picks up a distinct PartitionSpec. + mesh = Mesh(np.array(jax.local_devices()[:1]).reshape(1, 1), ("data", "model")) + target_pspec = PartitionSpec(("data", "model")) + new_mu_sharding = NamedSharding(mesh, target_pspec) + + # After _build_state_mesh_shardings, every leaf's .value is a NamedSharding (no .shape), + # so we just override every Variable leaf in mu in place. + # After _build_state_mesh_shardings, every leaf's value is a NamedSharding (no .shape), + # so we just override every Variable leaf in mu in place via set_value (modern API). 
+ mu_state = state_mesh_shardings.optimizer.opt_state[0]["mu"] + for var in jax.tree.leaves(mu_state, is_leaf=lambda x: isinstance(x, nnx.Variable)): + if isinstance(var, nnx.Variable): + var.set_value(new_mu_sharding) + + _, updated = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + + # All Param leaves under updated.model must now share the new mu sharding. + param_leaves = jax.tree.leaves(updated.model, is_leaf=lambda x: isinstance(x, nnx.Variable)) + param_leaves = [v for v in param_leaves if isinstance(v, nnx.Param)] + self.assertGreater(len(param_leaves), 0) + for leaf in param_leaves: + self.assertEqual(leaf.get_value().spec, target_pspec) + + def test_raises_when_no_adam_state_present(self): + """Stateless optimizers (e.g., SGD) have no mu — function must raise NotImplementedError.""" + cfg = _Cfg(shard_optimizer_over_data=True) + state_mesh_shardings = _build_state_mesh_shardings(self.model, optax.sgd(1e-3)) + with self.assertRaises(NotImplementedError): + sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + + def test_chained_optimizer_recursion_finds_adam_mu(self): + """A nested optax.chain(clip, adam) wraps mu under multiple containers — recursion must find it.""" + cfg = _Cfg(shard_optimizer_over_data=True) + chained = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3)) + state_mesh_shardings = _build_state_mesh_shardings(self.model, chained) + + # Should not raise; verify update happens (params replaced with mu shardings). + prev, updated = sharding.maybe_update_params_sharding_with_opt_nnx(cfg, state_mesh_shardings) + self.assertIsInstance(prev, nnx.State) + self.assertIsInstance(updated, nnx.State) + # Same number of Param leaves before and after. 
+ n_prev = len(jax.tree.leaves(prev, is_leaf=lambda x: isinstance(x, nnx.Variable))) + n_after = len( + [ + v + for v in jax.tree.leaves(updated.model, is_leaf=lambda x: isinstance(x, nnx.Variable)) + if isinstance(v, nnx.Param) + ] + ) + self.assertEqual(n_prev, n_after) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_nnx_test.py b/tests/unit/train_nnx_test.py new file mode 100644 index 0000000000..3495b4c557 --- /dev/null +++ b/tests/unit/train_nnx_test.py @@ -0,0 +1,239 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX paths of loss_fn / train_step / eval_step in pre_train.train. + +These tests exercise the NNX branches without standing up a real Transformer or +data pipeline. We use a tiny NNX module that mimics the call signature the +production loss_fn uses (decoder_input_tokens, decoder_positions, ...). 
+""" + +import unittest +from dataclasses import dataclass + +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.layers import train_state_nnx +from maxtext.trainers.pre_train import train as pre_train + + +@dataclass +class _Cfg: + """Subset of HyperParameters used by loss_fn / train_step / eval_step.""" + + micro_batch_size_to_train_on: int = 2 + micro_batch_size_to_eval_on: int = 2 + vocab_size: int = 8 + z_loss_multiplier: float = 0.0 + enable_dropout: bool = False + use_multimodal: bool = False + use_indexer: bool = False + indexer_sparse_training: bool = False + indexer_loss_scaling_factor: float = 0.0 + num_vocab_tiling: int = 1 + num_experts: int = 1 + routed_bias: bool = False + routed_bias_update_rate: float = 0.0 + mtp_num_layers: int = 0 + mtp_eval_target_module: int = 0 + use_dpo: bool = False + use_qk_clip: bool = False + use_tunix_gradient_accumulation: bool = False + gradient_accumulation_steps: int = 1 + shard_optimizer_over_data: bool = False + optimizer_memory_host_offload: bool = False + parameter_memory_host_offload: bool = False + gradient_clipping_threshold: float = 0.0 + grad_dtype: jnp.dtype = jnp.float32 + record_internal_nn_metrics: bool = False + skip_step_on_spikes: bool = False + shard_mode: int = 0 # ShardMode.AUTO + weight_sparsity_n: int = 0 + weight_sparsity_m: int = 0 + + +class _TinyDecoder(nnx.Module): + """Mimics NNXDecoder.__call__ enough for loss_fn to run end-to-end. + + Returns logits of shape [batch, seq_len, vocab_size]. Ignores all multimodal + / dropout / target arguments — they exist only to match the keyword signature. 
+ """ + + def __init__(self, vocab_size: int, hidden: int, rngs: nnx.Rngs): + self.embed = nnx.Embed(vocab_size, hidden, rngs=rngs) + self.proj = nnx.Linear(hidden, vocab_size, rngs=rngs) + + def __call__( + self, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + encoder_images=None, + encoder_image_masks=None, + enable_dropout=False, + decoder_target_tokens=None, + decoder_target_mask=None, + ): + del decoder_positions, decoder_segment_ids, encoder_images, encoder_image_masks + del enable_dropout, decoder_target_tokens, decoder_target_mask + h = self.embed(decoder_input_tokens) + return self.proj(h) + + +def _make_data(batch=2, seq=4, vocab=8): + return { + "inputs": jnp.zeros((batch, seq), dtype=jnp.int32), + "inputs_position": jnp.broadcast_to(jnp.arange(seq), (batch, seq)), + "inputs_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + "targets": jnp.zeros((batch, seq), dtype=jnp.int32), + "targets_segmentation": jnp.ones((batch, seq), dtype=jnp.int32), + } + + +def _build_state(): + cfg = _Cfg() + model = _TinyDecoder(cfg.vocab_size, hidden=4, rngs=nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, optax.sgd(0.01), wrt=nnx.Param) + ts = train_state_nnx.TrainStateNNX(model, optimizer) + return cfg, ts + + +class TestLossFnNNX(unittest.TestCase): + """Cover the NNX branch of loss_fn (lines 178-213).""" + + def test_returns_loss_and_full_aux_dict(self): + cfg, ts = _build_state() + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + loss, aux = pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) + self.assertTrue(jnp.isfinite(loss)) + # Aux schema relied on by train_step / eval_step / GA. + for key in ( + "intermediate_outputs", + "xent_sum", + "z_loss", + "total_weights", + "moe_lb_loss", + "indexer_loss", + "moe_bias_updates", + "mtp_loss", + ): + self.assertIn(key, aux) + # NNX intermediates are captured into a pure-dict snapshot, then logits attached. 
+ self.assertIsInstance(aux["intermediate_outputs"], dict) + self.assertIn("logits", aux["intermediate_outputs"]) + + def test_eval_mode_truncates_to_eval_micro_batch(self): + cfg, ts = _build_state() + cfg.micro_batch_size_to_eval_on = 1 + data = _make_data(batch=2, vocab=cfg.vocab_size) + loss, aux = pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=False) + self.assertTrue(jnp.isfinite(loss)) + # eval truncated batch to 1 → total_weights = seq_len * 1 + self.assertEqual(int(aux["total_weights"]), data["targets_segmentation"].shape[1]) + + def test_indexer_dense_warmup_skips_xent(self): + cfg, ts = _build_state() + cfg.use_indexer = True + cfg.indexer_sparse_training = False + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + loss, aux = pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) + # When dense warm-up is active the loss_fn skips the main loss entirely. + self.assertEqual(float(aux["xent_sum"]), 0.0) + self.assertEqual(float(loss), 0.0) + + def test_vocab_tiling_raises_not_implemented(self): + cfg, ts = _build_state() + cfg.num_vocab_tiling = 4 + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + with self.assertRaises(NotImplementedError): + pre_train.loss_fn(ts.model, cfg, data, None, None, is_train=True) + + +class TestTrainStepNNX(unittest.TestCase): + """Cover the NNX branch of train_step (the diff_wrapper / nnx.update path).""" + + def test_train_step_returns_state_and_metrics(self): + cfg, ts = _build_state() + state_graphdef, state_pure = nnx.split(ts) + + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + new_state, metrics = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + # NNX path returns nnx.State (via nnx.state(new_state)) and a metrics dict. 
+ self.assertIsInstance(new_state, nnx.State) + self.assertIn("scalar", metrics) + self.assertIn("learning/loss", metrics["scalar"]) + self.assertIn("learning/grad_norm", metrics["scalar"]) + self.assertIn("learning/param_norm", metrics["scalar"]) + self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) + + def test_train_step_dpo_raises_for_nnx(self): + cfg, ts = _build_state() + cfg.use_dpo = True + state_graphdef, state_pure = nnx.split(ts) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + with self.assertRaises(NotImplementedError): + pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + + def test_train_step_increments_optimizer_step(self): + cfg, ts = _build_state() + state_graphdef, state_pure = nnx.split(ts) + pre_step = int(state_pure.optimizer.step.get_value()) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + new_state, _ = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + self.assertEqual(int(new_state.optimizer.step.get_value()), pre_step + 1) + + def test_train_step_with_gradient_clipping(self): + """The clipping branch (gradient_clipping_threshold > 0) must run without raising.""" + cfg, ts = _build_state() + cfg.gradient_clipping_threshold = 1.0 + state_graphdef, state_pure = nnx.split(ts) + data = _make_data(batch=cfg.micro_batch_size_to_train_on, vocab=cfg.vocab_size) + new_state, metrics = pre_train.train_step( + state_graphdef, cfg, state_mesh_shardings=None, params_shardings=None, state=state_pure, data=data + ) + self.assertIsInstance(new_state, nnx.State) + self.assertTrue(jnp.isfinite(metrics["scalar"]["learning/loss"])) + + +class TestEvalStepNNX(unittest.TestCase): + """Cover the NNX branch of eval_step (lines 568-570).""" + + def test_eval_step_returns_metrics(self): + cfg, ts = _build_state() + 
state_graphdef, state_pure = nnx.split(ts) + data = _make_data(batch=cfg.micro_batch_size_to_eval_on, vocab=cfg.vocab_size) + metrics = pre_train.eval_step(state_graphdef, cfg, state_pure, data) + self.assertIn("scalar", metrics) + for key in ( + "evaluation/loss", + "evaluation/total_loss", + "evaluation/total_weights", + "evaluation/moe_lb_loss", + ): + self.assertIn(key, metrics["scalar"]) + # NNX path must NOT include DPO eval metric. + self.assertNotIn("evaluation/dpo_reward_accuracy", metrics["scalar"]) + self.assertTrue(jnp.isfinite(metrics["scalar"]["evaluation/loss"])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_state_nnx_checkpoint_test.py b/tests/unit/train_state_nnx_checkpoint_test.py new file mode 100644 index 0000000000..100d3f81e1 --- /dev/null +++ b/tests/unit/train_state_nnx_checkpoint_test.py @@ -0,0 +1,399 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""TrainStateNNX checkpoint tests.""" + +import pathlib +import tempfile +import shutil +from types import SimpleNamespace +from unittest import mock + +import unittest +import jax +import jax.numpy as jnp +from flax import nnx, serialization +from flax import linen as nn +from flax.training import train_state +import optax +import orbax.checkpoint as ocp + +from maxtext.common import checkpointing +from maxtext.layers import train_state_nnx + + +class MockModel(nnx.Module): + """A simple model for checkpoint testing.""" + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + def __call__(self, x): + return self.linear(x) + + +class LinenMockModel(nn.Module): + """The Linen equivalent of the MockModel.""" + + @nn.compact + def __call__(self, x): + # We name the layer 'linear' to match the attribute name in the NNX MockModel + return nn.Dense(features=1, name="linear")(x) + + +class TestTrainStateNNXCheckpoint(unittest.TestCase): + """Class to test NNX checkpoint.""" + + def setUp(self): + self.rngs = nnx.Rngs(0) + self.model = MockModel(rngs=self.rngs) + + # Setup a chained optimizer: Gradient Clipping -> Adam + # Note: optax.adam is also a chain (scale_by_adam + scale_by_learning_rate). + # This creates a nested state structure: (EmptyState, (ScaleByAdamState, EmptyState)) + self.tx = optax.chain( + optax.clip_by_global_norm(max_norm=1.0), + optax.adam(1e-3), + ) + + def test_checkpoint_structure(self): + """Ensures the state object contains both model and optimizer keys.""" + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(self.model, optimizer) + + # We use .to_pure_dict() to simulate the format stored in a checkpoint. + # This converts nnx.Variable/State objects into raw arrays and dictionaries. + full_state = nnx.state(state).to_pure_dict() + + # 1. Verify Top-level Keys + self.assertIn("model", full_state) + self.assertIn("optimizer", full_state) + + # 2. 
Verify Optimizer Internal Structure + opt_inner_state = full_state["optimizer"]["opt_state"] + + # Because we used optax.chain(clip, adam), index 0 is clip, index 1 is adam. + # Since adam is also a chain, index 1 is itself a dictionary/tuple representation. + # Adam's momentum (mu/nu) is in the first element of its own sub-chain. + adam_component = opt_inner_state[1][0] + + self.assertIn("mu", adam_component, "Adam 'mu' buffer not found in pure dict state.") + self.assertIn("nu", adam_component, "Adam 'nu' buffer not found in pure dict state.") + + # In a pure dict, these are nested dictionaries containing arrays, not NNX objects. + self.assertIsInstance(adam_component["mu"], dict) + self.assertIsInstance(adam_component["nu"], dict) + + # To verify a specific leaf, we navigate the dictionary hierarchy: + self.assertIsInstance(adam_component["mu"]["linear"]["kernel"], jax.Array) + + def test_checkpoint_and_restore(self): + """Verifies that the full state can be captured and restored into a new instance.""" + # 1. Initialize original state and optimizer + optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param) + state_original = train_state_nnx.TrainStateNNX(self.model, optimizer) + + # 2. Perform a training step to modify weights and optimizer buffers + def loss_fn(m): + return jnp.mean(m(jnp.ones((1, 2))) ** 2) + + grads = nnx.grad(loss_fn)(state_original.model) + state_original.apply_gradients(grads) + + # Capture state after one step + original_kernel_val = state_original.model.linear.kernel.value + original_step_val = state_original.optimizer.step.value + self.assertEqual(original_step_val, 1) + + # 3. Capture the "Checkpoint" as a pure dictionary + checkpoint_state = nnx.state(state_original).to_pure_dict() + + # 4. 
Initialize a fresh, different instance + new_rngs = nnx.Rngs(1) + new_model = MockModel(rngs=new_rngs) + new_optimizer = nnx.Optimizer(new_model, self.tx, wrt=nnx.Param) + state_restored = train_state_nnx.TrainStateNNX(new_model, new_optimizer) + + # Check differences before restoration + self.assertEqual(state_restored.optimizer.step.value, 0) + self.assertFalse(jnp.allclose(state_restored.model.linear.kernel.value, original_kernel_val)) + + # 5. Restore the state into the new instance. + # nnx.update supports updating from a pure dictionary. + nnx.update(state_restored, checkpoint_state) + + # 6. Verify restoration + # Check step counter + self.assertEqual(state_restored.optimizer.step.value, original_step_val) + # Check model weights + self.assertTrue(jnp.allclose(state_restored.model.linear.kernel.value, original_kernel_val)) + + # Check that it can still be trained after restoration + new_grads = nnx.grad(loss_fn)(state_restored.model) + state_restored.apply_gradients(new_grads) + self.assertEqual(state_restored.optimizer.step.value, 2) + + def test_restore_from_linen_state(self): + """Verifies a multi-stage migration: Linen CKPT -> Migrate -> NNX CKPT -> Restore.""" + # 1. 
Setup Linen TrainState (Simulating original training) + linen_model = LinenMockModel() + dummy_input = jnp.ones((1, 2)) + variables = linen_model.init(jax.random.key(42), dummy_input) + + state_linen = train_state.TrainState.create(apply_fn=linen_model.apply, params=variables["params"], tx=self.tx) + + # Perform a step to populate optimizer buffers + grads = jax.tree.map(jnp.ones_like, state_linen.params) + state_linen = state_linen.apply_gradients(grads=grads) + + temp_dir = pathlib.Path(tempfile.mkdtemp()) + try: + # --- PHASE 1: Save Legacy Linen Checkpoint --- + linen_ckpt_dir = temp_dir / "linen_ckpt" + mngr_linen = ocp.CheckpointManager( + linen_ckpt_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + mngr_linen.save(0, args=ocp.args.StandardSave(state_linen)) + mngr_linen.wait_until_finished() + + # --- PHASE 2: Read Linen CKPT and Convert to NNX Structure --- + # Load it back without knowing the blueprint (reading as a pure PyTree) + restored_linen_obj = mngr_linen.restore(0) + + # Convert the restored object to a pure dictionary structure. + restored_linen_dict = serialization.to_state_dict(restored_linen_obj) + + # Helper to recursively convert string keys back to integers + # and filter out None values. + def recursive_clean(obj): + if isinstance(obj, dict): + return {int(k) if k.isdigit() else k: recursive_clean(v) for k, v in obj.items() if v is not None} + return obj + + # Converted dict - simple PyTree mapping, no NNX Module initialization needed here. + # This simulates a situation where the conversion logic is blueprint-agnostic. 
+ linen_as_nnx_dict = { + "model": restored_linen_dict["params"], + "optimizer": { + "step": jnp.array(restored_linen_dict["step"]), + "opt_state": recursive_clean(restored_linen_dict["opt_state"]), + }, + } + + # --- PHASE 3: Save as Native NNX Checkpoint --- + nnx_ckpt_dir = temp_dir / "nnx_ckpt" + mngr_nnx = ocp.CheckpointManager( + nnx_ckpt_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + # We save the raw dictionary directly to disk. + mngr_nnx.save(0, args=ocp.args.StandardSave(linen_as_nnx_dict)) + mngr_nnx.wait_until_finished() + + # --- PHASE 4: Restore from NNX Checkpoint to target Model --- + nnx_model = MockModel(rngs=nnx.Rngs(0)) + nnx_optimizer = nnx.Optimizer(nnx_model, self.tx, wrt=nnx.Param) + state_nnx = train_state_nnx.TrainStateNNX(nnx_model, nnx_optimizer) + + # We now restore using the nnx.State as a blueprint. This ensures Orbax + # correctly maps the arrays on disk to the model's structural expectation. + blueprint = nnx.state(state_nnx).to_pure_dict() + restored_nnx_pytree = mngr_nnx.restore(0, args=ocp.args.StandardRestore(item=blueprint)) + nnx.update(state_nnx, restored_nnx_pytree) + + # --- PHASE 5: Verification --- + # 1. Verify Step + self.assertEqual(state_nnx.optimizer.step.value, 1) + + # 2. Verify Weights + self.assertTrue(jnp.allclose(state_nnx.model.linear.kernel.value, state_linen.params["linear"]["kernel"])) + + # 3. Verify Chained Optimizer State (Clip at index 0, Adam at index 1) + self.assertEqual(type(state_nnx.optimizer.opt_state[0]), type(state_linen.opt_state[0])) + + # state_linen.opt_state[1] is the Adam chain state. + # state_linen.opt_state[1][0] is the ScaleByAdamState containing 'mu'. 
+ self.assertTrue( + jnp.allclose( + state_nnx.optimizer.opt_state[1][0].mu["linear"]["kernel"], + state_linen.opt_state[1][0].mu["linear"]["kernel"], + ) + ) + + finally: + # Cleanup temporary directory + shutil.rmtree(temp_dir) + + def test_restore_from_checkpoint_model_params(self): + """Verifies that model parameters can be restored from model params only.""" + # 1. Setup mocked parameters manually (no Linen model needed for setup) + # This structure matches the path model.linear.kernel/bias in the NNX MockModel. + mock_params = {"linear": {"kernel": jnp.ones((2, 1)) * 9.0, "bias": jnp.zeros((1,))}} + + # Simplified checkpoint dictionary using hardcoded mocked params as requested + checkpoint_dict = { + "model": mock_params, + } + + temp_dir = pathlib.Path(tempfile.mkdtemp()) + try: + # --- PHASE 1: Save the partial checkpoint --- + mngr = ocp.CheckpointManager( + temp_dir, options=ocp.CheckpointManagerOptions(create=True), item_handlers=ocp.StandardCheckpointHandler() + ) + mngr.save(0, args=ocp.args.StandardSave(checkpoint_dict)) + mngr.wait_until_finished() + + # --- PHASE 2: Restore into a full TrainStateNNX --- + nnx_model = MockModel(rngs=nnx.Rngs(0)) + nnx_optimizer = nnx.Optimizer(nnx_model, self.tx, wrt=nnx.Param) + state_nnx = train_state_nnx.TrainStateNNX(nnx_model, nnx_optimizer) + + # We use nnx.state to get a full blueprint as a reference. + full_nnx_pure_dict = nnx.state(state_nnx).to_pure_dict() + blueprint = {"model": full_nnx_pure_dict["model"]} + + # If we don't know if the checkpoint on disk has 'optimizer' or not, we simulate + # schema-agnostic restoration by calling restore without a blueprint. + # This avoids Orbax structural mismatch errors while allowing us to see the data. + restored_pytree = mngr.restore(0, args=ocp.args.StandardRestore(item=blueprint)) + + # Use nnx.update to apply the restored data to the stateful NNX object. 
+ # nnx.update is naturally partial: it will update 'model' from the restored dict + # and leave 'optimizer' untouched at its initialized value. + nnx.update(state_nnx, restored_pytree) + + # --- PHASE 3: Verification --- + # Check that weights were restored to the specific mock values + self.assertTrue(jnp.allclose(state_nnx.model.linear.kernel.value, mock_params["linear"]["kernel"])) + # Step remains at its initialized value (0) because it was not in the checkpoint + self.assertEqual(state_nnx.optimizer.step.value, 0) + + # Verify that the optimizer state still exists in the object (initialized) + # even though it was not provided in the checkpoint. + # Adam's state is at index 1 of the chain, and it's a nested structure (tuple). + # We verify that index 0 (ScaleByAdamState) contains the 'mu' State container. + self.assertIsInstance(state_nnx.optimizer.opt_state[1][0].mu, nnx.State) + + finally: + # Cleanup temporary directory + shutil.rmtree(temp_dir) + + +class TestMaybeSaveCheckpointStepAlignment(unittest.TestCase): + """Verify maybe_save_checkpoint's fallback step matches the last completed step. + + When the training loop's final save calls maybe_save_checkpoint without an + explicit `step`, it derives `actual_step` from the state: + - NNX: int(state.optimizer.step) - 1 + - Linen: int(state.step) - 1 + Both TrainStateNNX.apply_gradients (via nnx.Optimizer.update) and Linen + TrainState.apply_gradients increment the counter by 1 per call, so after N + gradient applications the counter is N and the "last completed step" is N-1. 
+ """ + + N_STEPS = 5 + + def setUp(self): + self.tx = optax.adam(1e-3) + + def _build_nnx_state(self, num_steps): + """Build an nnx.State flattened from TrainStateNNX after num_steps gradient applications.""" + model = MockModel(rngs=nnx.Rngs(0)) + optimizer = nnx.Optimizer(model, self.tx, wrt=nnx.Param) + state = train_state_nnx.TrainStateNNX(model, optimizer) + + def loss_fn(m): + return jnp.mean(m(jnp.ones((1, 2))) ** 2) + + for _ in range(num_steps): + grads = nnx.grad(loss_fn)(state.model) + state.apply_gradients(grads) + # maybe_save_checkpoint is called with a flat nnx.State in the NNX path + # (train_step returns nnx.state(new_state)). + return nnx.state(state) + + def _build_linen_state(self, num_steps): + """Build a Linen TrainState after num_steps gradient applications.""" + model = LinenMockModel() + variables = model.init(jax.random.key(0), jnp.ones((1, 2))) + state = train_state.TrainState.create(apply_fn=model.apply, params=variables["params"], tx=self.tx) + grads = jax.tree.map(jnp.ones_like, state.params) + for _ in range(num_steps): + state = state.apply_gradients(grads=grads) + return state + + def _invoke_maybe_save(self, state, pure_nnx): + """Call maybe_save_checkpoint with save_checkpoint patched, return {step, state} captured.""" + # checkpoint_period=1 keeps force_ckpt_save False regardless of actual_step. 
+ config = SimpleNamespace(pure_nnx=pure_nnx, checkpoint_period=1, async_checkpointing=False) + mgr = mock.MagicMock() + mgr.reached_preemption.return_value = False + + captured = {} + + def fake_save_checkpoint(_mgr, step, state_arg, *_args, **_kwargs): + captured["step"] = step + captured["state"] = state_arg + return False # no save happened => print_save_message is skipped + + with mock.patch.object(checkpointing, "save_checkpoint", side_effect=fake_save_checkpoint): + checkpointing.maybe_save_checkpoint(mgr, state, config, data_iterator=None, step=None) + return captured + + def test_nnx_final_save_step_is_n_minus_1(self): + state = self._build_nnx_state(self.N_STEPS) + self.assertEqual(int(state.optimizer.step.value), self.N_STEPS) + captured = self._invoke_maybe_save(state, pure_nnx=True) + self.assertEqual(captured["step"], self.N_STEPS - 1) + + def test_linen_final_save_step_is_n_minus_1(self): + state = self._build_linen_state(self.N_STEPS) + self.assertEqual(int(state.step), self.N_STEPS) + captured = self._invoke_maybe_save(state, pure_nnx=False) + self.assertEqual(captured["step"], self.N_STEPS - 1) + + def test_nnx_and_linen_agree_on_actual_step(self): + """TrainStateNNX and Linen TrainState must yield the same fallback actual_step.""" + nnx_state = self._build_nnx_state(self.N_STEPS) + linen_state = self._build_linen_state(self.N_STEPS) + self.assertEqual( + self._invoke_maybe_save(nnx_state, pure_nnx=True)["step"], + self._invoke_maybe_save(linen_state, pure_nnx=False)["step"], + ) + + def test_nnx_state_is_converted_to_pure_dict_before_save(self): + """For pure_nnx=True, maybe_save_checkpoint must pass a plain dict to save_checkpoint, not an nnx.State.""" + state = self._build_nnx_state(self.N_STEPS) + self.assertIsInstance(state, nnx.State) # precondition: NNX train_step returns an nnx.State + + captured = self._invoke_maybe_save(state, pure_nnx=True) + + # save_checkpoint should have received a plain Python dict (the result of + # 
nnx.State.to_pure_dict()), not the original nnx.State. + self.assertIsInstance(captured["state"], dict) + self.assertNotIsInstance(captured["state"], nnx.State) + # Sanity: the converted dict still mirrors the TrainStateNNX structure. + self.assertIn("model", captured["state"]) + self.assertIn("optimizer", captured["state"]) + + def test_linen_state_is_passed_through_unchanged(self): + """For pure_nnx=False, maybe_save_checkpoint must pass the original TrainState object through.""" + state = self._build_linen_state(self.N_STEPS) + captured = self._invoke_maybe_save(state, pure_nnx=False) + # Linen path must not invoke to_pure_dict(); state is forwarded as-is. + self.assertIs(captured["state"], state) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_state_nnx_test.py b/tests/unit/train_state_nnx_test.py new file mode 100644 index 0000000000..03db77ff63 --- /dev/null +++ b/tests/unit/train_state_nnx_test.py @@ -0,0 +1,90 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""TrainStateNNX tests."""
+
+import unittest
+import jax.numpy as jnp
+from flax import nnx
+import optax
+
+from maxtext.layers import train_state_nnx
+
+
+class MockModel(nnx.Module):
+  """Mocked NNX model."""
+
+  def __init__(self, rngs: nnx.Rngs):
+    self.linear = nnx.Linear(2, 1, rngs=rngs)
+
+  def __call__(self, x):
+    return self.linear(x)
+
+
+class TestTrainStateNNX(unittest.TestCase):
+  """TrainStateNNX tests."""
+
+  def setUp(self):
+    self.rngs = nnx.Rngs(0)
+    self.model = MockModel(rngs=self.rngs)
+    self.tx = optax.adam(1e-3)
+
+  def test_init_with_optimizer(self):
+    """Test init with optimizer."""
+    optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param)
+    state = train_state_nnx.TrainStateNNX(self.model, optimizer)
+
+    self.assertEqual(state.model, self.model)
+    self.assertEqual(state.optimizer, optimizer)
+    # Access step directly from optimizer
+    self.assertEqual(state.optimizer.step.value, 0)
+
+  def test_init_without_optimizer(self):
+    """Test init without optimizer."""
+    state = train_state_nnx.TrainStateNNX(self.model, None)
+
+    self.assertEqual(state.model, self.model)
+    self.assertIsNone(state.optimizer)
+
+  def test_apply_gradients_success(self):
+    """Test apply gradients can be called successfully."""
+    optimizer = nnx.Optimizer(self.model, self.tx, wrt=nnx.Param)
+    state = train_state_nnx.TrainStateNNX(self.model, optimizer)
+
+    # Create dummy gradients matching the model state structure
+    def loss_fn(m):
+      return jnp.mean(m(jnp.ones((1, 2))) ** 2)
+
+    grads = nnx.grad(loss_fn)(state.model)
+
+    # Apply gradients
+    state.apply_gradients(grads)
+
+    # Verify step incremented (managed by nnx.Optimizer)
+    self.assertEqual(state.optimizer.step.value, 1)
+
+  def test_apply_gradients_raises_runtime_error(self):
+    """Test apply gradients without an optimizer."""
+    # Initialize without optimizer (inference mode)
+    state = train_state_nnx.TrainStateNNX(self.model, None)
+
+    dummy_grads = {}
+    with self.assertRaises(RuntimeError) as cm:
+
state.apply_gradients(dummy_grads) + + self.assertIn("inference only", str(cm.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/train_utils_nnx_test.py b/tests/unit/train_utils_nnx_test.py new file mode 100644 index 0000000000..2ff7276fd9 --- /dev/null +++ b/tests/unit/train_utils_nnx_test.py @@ -0,0 +1,149 @@ +# Copyright 2025-2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX-specific helpers / patterns in train_utils.setup_train_loop. + +setup_train_loop itself is integration territory (it touches data iterators, +checkpoint managers, and a real mesh), so we cover the NNX-only pieces that +have unit-testable contracts: + + 1. The create_train_state_fn closure pattern: builds nnx.Optimizer + TrainStateNNX + from a zero-arg model factory and a transform. + 2. nnx.split(state.model, nnx.Param, ...) returns Param-only state used to + compute state_params / state_mesh_shardings_params. + 3. nnx.merge(state_graphdef, state) reconstitutes a TrainStateNNX from the + pure-state form returned by setup_training_state. 
+""" + +import unittest +from functools import partial + +import jax +import jax.numpy as jnp +import optax +from flax import nnx + +from maxtext.layers import train_state_nnx + + +class _Model(nnx.Module): + + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(2, 1, rngs=rngs) + + +class TestCreateTrainStateFnClosure(unittest.TestCase): + """Exercise the closure pattern in setup_train_loop: + + def create_train_state_fn(): + model = _create_model_partial() + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(model, optimizer) + """ + + def test_returns_train_state_nnx_with_optimizer(self): + tx = optax.sgd(0.01) + + def _create_model(): + return _Model(rngs=nnx.Rngs(0)) + + def create_train_state_fn(): + model = _create_model() + optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(model, optimizer) + + state = create_train_state_fn() + self.assertIsInstance(state, train_state_nnx.TrainStateNNX) + self.assertIsInstance(state.optimizer, nnx.Optimizer) + self.assertEqual(int(state.optimizer.step.get_value()), 0) + + def test_two_invocations_produce_independent_states(self): + """The lambda must call the factory each time (otherwise checkpoint init/restore would alias).""" + tx = optax.sgd(0.01) + counter = {"n": 0} + + def _create_model(): + counter["n"] += 1 + return _Model(rngs=nnx.Rngs(counter["n"])) + + def create_train_state_fn(): + model = _create_model() + return train_state_nnx.TrainStateNNX(model, nnx.Optimizer(model, tx, wrt=nnx.Param)) + + s1 = create_train_state_fn() + s2 = create_train_state_fn() + self.assertEqual(counter["n"], 2) + self.assertIsNot(s1.model, s2.model) + + +class TestSetupTrainLoopNNXTreeOps(unittest.TestCase): + """Cover the nnx.split(state.model, nnx.Param, ...) 
and nnx.merge round-trip + patterns that setup_train_loop uses to derive Param-only views and rebuild + the full TrainStateNNX before returning.""" + + def setUp(self): + self.tx = optax.sgd(0.01) + self.model = _Model(rngs=nnx.Rngs(0)) + self.state = train_state_nnx.TrainStateNNX(self.model, nnx.Optimizer(self.model, self.tx, wrt=nnx.Param)) + + def test_nnx_split_yields_param_only_state(self): + """state_params used for assert_params_sufficiently_sharded must contain only nnx.Param leaves.""" + _, state_params, _ = nnx.split(self.state.model, nnx.Param, ...) + leaves = jax.tree.leaves(state_params, is_leaf=lambda x: isinstance(x, nnx.Variable)) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + self.assertIsInstance(leaf, nnx.Param) + + def test_nnx_merge_reconstructs_train_state_nnx(self): + """setup_train_loop ends with nnx.merge(state_graphdef, state) — verify that round-trips.""" + state_graphdef, state_pure = nnx.split(self.state) + train_state = nnx.merge(state_graphdef, state_pure) + self.assertIsInstance(train_state, train_state_nnx.TrainStateNNX) + # Same numeric values. + self.assertTrue(jnp.allclose(train_state.model.linear.kernel.value, self.state.model.linear.kernel.value)) + + +class TestInitStateFnIsCallable(unittest.TestCase): + """For the Linen path setup_train_loop builds init_state_fn = partial(...). + + The NNX path uses a closure instead — confirm both forms have the + zero-argument call contract create_checkpoint_manager / setup_training_state expect. 
+ """ + + def test_nnx_init_state_fn_callable_with_no_args(self): + tx = optax.sgd(0.01) + + def _create_model(): + return _Model(rngs=nnx.Rngs(0)) + + def init_state_fn(): + model = _create_model() + return train_state_nnx.TrainStateNNX(model, nnx.Optimizer(model, tx, wrt=nnx.Param)) + + state = init_state_fn() # must not raise / require args + self.assertIsInstance(state, train_state_nnx.TrainStateNNX) + + def test_linen_init_state_fn_is_partial_callable_with_no_args(self): + """Sanity: the Linen-side `partial(init_initial_state, model, tx, config, is_training, init_rng)` form.""" + + def init_initial_state(model, tx, config, is_training, init_rng): + del model, tx, config, is_training, init_rng + return "linen-state" + + init_state_fn = partial(init_initial_state, "model", "tx", "config", True, "rng") + self.assertEqual(init_state_fn(), "linen-state") + + +if __name__ == "__main__": + unittest.main()