diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py index 7b670dd8d7..b8a513c1df 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_gpt3_ckpt_from_paxml.py @@ -40,6 +40,7 @@ import os import sys +from flax import nnx import jax from jax import random from jax.sharding import Mesh @@ -48,11 +49,15 @@ from maxtext.common import checkpointing from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.layers import quantizations +from maxtext.common import train_state_nnx from maxtext.models.models import transformer_as_linen from maxtext.optimizers import optimizers from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +from maxtext.utils import maxtext_utils_nnx +from maxtext.utils import model_creation_utils +from maxtext.utils import train_utils import numpy as np from psutil import Process import tensorstore as ts @@ -87,12 +92,23 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name devices_array = maxtext_utils.create_device_mesh(cfg) mesh = Mesh(devices_array, cfg.mesh_axes) - # Output is Linen-format (keystr_map below uses Linen tree paths). Route to - # Linen regardless of pure_nnx. - quant = quantizations.configure_quantization(cfg) - model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) - learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg) - tx = optimizers.get_optimizer(cfg, learning_rate_schedule) + if cfg.pure_nnx: + rngs = maxtext_utils_nnx.create_nnx_rngs(cfg, rng_key=init_rng) + model = model_creation_utils.from_config(cfg, mesh=mesh, rngs=rngs) + _, tx = train_utils.create_training_optimizer(cfg, model) + _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(cfg, mesh) + + def init_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(nnx_model, optimizer) + + else: + quant = quantizations.configure_quantization(cfg) + model = transformer_as_linen(cfg, mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) + learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(cfg) + tx = optimizers.get_optimizer(cfg, learning_rate_schedule) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( cfg.checkpoint_dir, @@ -101,7 +117,6 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name cfg.checkpoint_period, ) - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, cfg, True, init_rng) state, _, _, _ = maxtext_utils.setup_training_state(None, cfg, mesh, checkpoint_manager, init_state_fn) max_logging.log("start") max_utils.print_mem_stats("After params initialized") @@ -186,10 +201,24 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name "['decoder']['decoder_norm']['bias']": (".params.lm.final_ln.bias", None), } - state_map = { - ".step": ("step", None), - ".opt_state.count": ("opt_states_0.no_prefix_0.count", None), - } + if cfg.pure_nnx: + # NNX state-tree paths after `nnx.split(TrainStateNNX)`. The state is a + # nested `nnx.State` (dict-like Mapping) with `nnx.Variable` leaves, so + # `jax.tree_util.keystr` produces dict-style entries (`['key']`) plus + # `.value` for the Variable leaf, plus `[idx]` for the optax tuple: + # model params -> ['model'].value + # adam mu / nu -> ['optimizer']['opt_state'][0]['mu' | 'nu'].value + # step -> ['optimizer']['step'].value + # opt count -> ['optimizer']['opt_state'][0]['count'].value + state_map = { + "['optimizer']['step'].value": ("step", None), + "['optimizer']['opt_state'][0]['count'].value": ("opt_states_0.no_prefix_0.count", None), + } + else: + state_map = { + ".step": ("step", None), + ".opt_state.count": ("opt_states_0.no_prefix_0.count", None), + } def get_layer_prefix(keystr_pax): # different path format between decoder_layer variable @@ -201,19 +230,27 @@ def get_layer_prefix(keystr_pax): return prefix_pax_opt_state for keystr_maxtext, (keystr_pax, transform_fn) in keystr_map.items(): - # model variable - state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn) prefix_pax_opt_state = get_layer_prefix(keystr_pax) - # first momentum in optimizer state - state_map[f".opt_state.mu['params']{keystr_maxtext}"] = ( - f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}", - transform_fn, - ) - # second momentum in optimizer state - state_map[f".opt_state.nu['params']{keystr_maxtext}"] = ( - f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}", - transform_fn, - ) + if cfg.pure_nnx: + state_map[f"['model']{keystr_maxtext}.value"] = (f"mdl_vars{keystr_pax}", transform_fn) + state_map[f"['optimizer']['opt_state'][0]['mu']{keystr_maxtext}.value"] = ( + f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}", + transform_fn, + ) + state_map[f"['optimizer']['opt_state'][0]['nu']{keystr_maxtext}.value"] = ( + f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}", + transform_fn, + ) + else: + state_map[f".params['params']{keystr_maxtext}"] = (f"mdl_vars{keystr_pax}", transform_fn) + state_map[f".opt_state.mu['params']{keystr_maxtext}"] = ( + f"opt_states_0.{prefix_pax_opt_state}.m{keystr_pax}", + transform_fn, + ) + state_map[f".opt_state.nu['params']{keystr_maxtext}"] = ( + f"opt_states_0.{prefix_pax_opt_state}.v{keystr_pax}", + transform_fn, + ) def verify_fn(key_path, _): keystr = jax.tree_util.keystr(key_path) @@ -265,10 +302,11 @@ def map_fn(key_path, value): max_logging.log("converted state finished") max_utils.print_mem_stats("converted state finished") - if checkpointing.save_checkpoint(checkpoint_manager, converted_state.step, converted_state): - max_logging.log(f"saved a checkpoint at step {converted_state.step}") + step_value = int(converted_state.optimizer.step.value) if cfg.pure_nnx else converted_state.step + if checkpointing.save_checkpoint(checkpoint_manager, step_value, converted_state): + max_logging.log(f"saved a checkpoint at step {step_value}") # Upon preemption, exit when and only when all ongoing saves are complete. - if checkpoint_manager.reached_preemption(converted_state.step): + if checkpoint_manager.reached_preemption(step_value): checkpoint_manager.wait_until_finished() sys.exit() diff --git a/src/maxtext/experimental/rl/grpo_trainer.py b/src/maxtext/experimental/rl/grpo_trainer.py index 282cb31a54..f6038ae018 100644 --- a/src/maxtext/experimental/rl/grpo_trainer.py +++ b/src/maxtext/experimental/rl/grpo_trainer.py @@ -486,15 +486,22 @@ def _train_step_nnx(model_graphdef, config, state_mesh_shardings, state, data): state = nnx.merge(model_graphdef, state) # Reconstruct the TrainStateNNX. policy_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...) - - def diff_wrapper(param, rest, config, data): + # Split the reference model into (graphdef, state) so we pass `ref_state` as + # an explicit pytree-typed argument to `diff_wrapper` instead of closing over + # the mutable nnx.Module — closure capture inside jax.value_and_grad works + # only by accident (Modules aren't registered JAX pytrees) and breaks the + # moment the reference forward touches any internal state. + ref_graphdef, ref_state = nnx.split(state.reference_model) + + def diff_wrapper(param, rest, ref_state, config, data): local_model = nnx.merge(policy_graphdef, param, rest, copy=True) - loss, aux = grpo_loss_fn_nnx(local_model, config, data, None, None, state.reference_model, is_train=True) + local_ref = nnx.merge(ref_graphdef, ref_state, copy=True) + loss, aux = grpo_loss_fn_nnx(local_model, config, data, None, None, local_ref, is_train=True) _, _, new_rest = nnx.split(local_model, nnx.Param, ...) return loss, (aux, new_rest) grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True) - (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data) + (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, ref_state, config, data) nnx.update(state.model, new_rest) if config.gradient_clipping_threshold > 0: @@ -798,8 +805,11 @@ def init_state_fn(): optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) # Reference uses the same init seed so it starts identical to the policy. reference_model = _create_model_partial() - # pylint: disable-next=unexpected-keyword-arg - return train_state_nnx.TrainStateNNX(nnx_model, optimizer, reference_model=reference_model) + # TrainStateNNX only takes (model, optimizer); reference_model is an NNX + # sibling attribute set after construction (nnx.Module is mutable). + state = train_state_nnx.TrainStateNNX(nnx_model, optimizer) + state.reference_model = reference_model + return state else: init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 9a13849248..6f7c777431 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -499,11 +499,11 @@ def move(path, value): "learning/total_weights": total_weights, } if config.use_qk_clip: - # Apply QK-Clip (Linen path only; NNX uses different state layout — TODO: implement for NNX) if isinstance(model, nn.Module): new_state = qk_clip_utils.apply_qk_clip(new_state, intermediate_outputs, config) + else: + new_state = qk_clip_utils.apply_qk_clip_nnx(new_state, intermediate_outputs, config) - # Report max_logits metric global_max_logit = qk_clip_utils.calculate_max_logit_metric(intermediate_outputs) if global_max_logit is not None: scalar_metrics["learning/max_logits"] = global_max_logit diff --git a/src/maxtext/utils/generate_param_only_checkpoint.py b/src/maxtext/utils/generate_param_only_checkpoint.py index 7661da296f..2b2f3f0dde 100644 --- a/src/maxtext/utils/generate_param_only_checkpoint.py +++ b/src/maxtext/utils/generate_param_only_checkpoint.py @@ -28,13 +28,16 @@ from absl import app from etils import epath +from flax import nnx import jax +import jax.numpy as jnp from jax import random from jax.sharding import Mesh from maxtext.configs import pyconfig from maxtext.common import checkpointing from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_TRAIN from maxtext.layers import quantizations +from maxtext.common import train_state_nnx from maxtext.models import models from maxtext.optimizers import optimizers from maxtext.utils import gcs_utils @@ -42,12 +45,18 @@ from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +from maxtext.utils import maxtext_utils_nnx +from maxtext.utils import model_creation_utils +from maxtext.utils import train_utils def _possibly_unroll_params(config, training_state, training_state_annotations, mesh): """Unroll scanned input layers when force_unroll is set.""" if not config.scan_layers or not config.force_unroll: return + if config.pure_nnx: + _possibly_unroll_params_nnx(config, training_state, training_state_annotations, mesh) + return def unroll_layer_group(num_layers, layer_name="layers"): """Helper function to unroll layers (e.g. dense or MoE) into individual layers.""" @@ -88,19 +97,85 @@ def slice_ith(input_layers): unroll_layer_group(config.num_decoder_layers, layer_name="layers") +def _possibly_unroll_params_nnx(config, state, state_mesh_shardings, mesh): + """NNX equivalent of _possibly_unroll_params. + + `state` is a flat `nnx.State` (post-split TrainStateNNX) with `state.model` + as a sub-State whose tree mirrors the model module hierarchy. Slices + `state.model.decoder[layer_name]` into per-index `layer_name_0..N` siblings + and removes the original collection. Mirrors the same operation on + `state_mesh_shardings` so downstream sharding stays correct. + """ + decoder_state = state.model.decoder + decoder_shardings = state_mesh_shardings.model.decoder + + def unroll_layer_group(num_layers, layer_name="layers"): + layers = decoder_state.get(layer_name, None) + layers_shardings = decoder_shardings.get(layer_name, None) + if layers is None or layers_shardings is None: + raise ValueError(f"Missing {layer_name} in NNX state.model.decoder or state_mesh_shardings.") + + def drop_scan_axis(named_sharding): + ps = named_sharding.spec + return jax.sharding.PartitionSpec(*(ps[0 : config.param_scan_axis] + ps[config.param_scan_axis + 1 :])) + + new_layer_pspec = jax.tree_util.tree_map( + drop_scan_axis, layers_shardings, is_leaf=lambda x: isinstance(x, jax.sharding.NamedSharding) + ) + new_layer_sharding = jax.tree_util.tree_map(lambda ps: jax.sharding.NamedSharding(mesh, ps), new_layer_pspec) + + for i in range(num_layers): + + def slice_ith(input_layers): + return jax.tree_util.tree_map(lambda x: jnp.take(x, i, axis=config.param_scan_axis), input_layers) + + # pylint: disable=not-callable + new_layer = jax.jit(slice_ith, out_shardings=new_layer_sharding)(layers) + + decoder_state[f"{layer_name}_{i}"] = new_layer + decoder_shardings[f"{layer_name}_{i}"] = new_layer_sharding + + decoder_state.pop(layer_name) + decoder_shardings.pop(layer_name) + jax.tree_util.tree_map(lambda x: x.delete() if hasattr(x, "delete") else None, layers) + + if config.decoder_block == DecoderBlockType.DEEPSEEK: + unroll_layer_group(config.first_num_dense_layers, layer_name="dense_layers") + unroll_layer_group(config.num_decoder_layers - config.first_num_dense_layers, layer_name="moe_layers") + else: + unroll_layer_group(config.num_decoder_layers, layer_name="layers") + + def _read_train_checkpoint(config, checkpoint_manager, mesh): """Read training checkpoint at path defined by load_full_state_path.""" - # Input and output are both Linen-format (downstream uses Linen tree paths). - # Route to Linen regardless of pure_nnx. - quant = quantizations.configure_quantization(config) - model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) - learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) - tx = optimizers.get_optimizer(config, learning_rate_schedule) - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) - state, state_mesh_notations, _, _ = maxtext_utils.setup_training_state( + if config.pure_nnx: + rngs = maxtext_utils_nnx.create_nnx_rngs(config, rng_key=rng) + model = model_creation_utils.from_config(config, mesh=mesh, rngs=rngs) + _, tx = train_utils.create_training_optimizer(config, model) + _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(config, mesh) + + def init_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(nnx_model, optimizer) + + else: + quant = quantizations.configure_quantization(config) + model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) + learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) + tx = optimizers.get_optimizer(config, learning_rate_schedule) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + + state, state_mesh_notations, state_mesh_shardings, _ = maxtext_utils.setup_training_state( None, config, mesh, checkpoint_manager, init_state_fn ) + if config.pure_nnx: + # On NNX, state is a flat nnx.State; params live under state.model and the + # legacy notations are unused (callers receive shardings directly). + num_params = max_utils.calculate_num_params_from_pytree(state.model) + max_logging.log(f"In input checkpoint Number of model params={num_params/1e9:.3f} billion") + return state, state_mesh_shardings num_params = max_utils.calculate_num_params_from_pytree(state.params) max_logging.log(f"In input checkpoint Number of model params={num_params/1e9:.3f} billion") return state, state_mesh_notations @@ -108,8 +183,9 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh): def _generate_lora_decode_checkpoints(config, mesh): """Read lora checkpoints checkpoint at path defined by load_full_state_path.""" - # LoRA adapters and decode checkpoints are both Linen-format (downstream uses Linen tree paths). - # Route to Linen regardless of pure_nnx. + if config.pure_nnx: + _generate_lora_decode_checkpoints_nnx(config, mesh) + return quant = quantizations.configure_quantization(config) model = models.transformer_as_linen(config, mesh, quant, MODEL_MODE_TRAIN) rng = random.PRNGKey(0) @@ -148,6 +224,9 @@ def _generate_lora_decode_checkpoints(config, mesh): def _save_decode_checkpoint(config, state, checkpoint_manager): """Generate checkpoint for decode from the training_state.""" + if config.pure_nnx: + _save_decode_checkpoint_nnx(config, state, checkpoint_manager) + return decode_state = maxtext_utils.init_decode_state( None, jax.tree_util.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params) ) @@ -157,6 +236,121 @@ def _save_decode_checkpoint(config, state, checkpoint_manager): checkpoint_manager.wait_until_finished() +def _save_decode_checkpoint_nnx(config, state, checkpoint_manager): + """Save a bf16 NNX-format param-only decode checkpoint. + + The on-disk shape mirrors what a vanilla NNX-trained checkpoint produces: a + plain dict tree of arrays (one per nnx.Param), with no Linen-style "params" + wrapper. This is the shape `from_pretrained` reads via its NNX-detection + branch (see model_creation_utils._adjust_target_for_moe_fusion / "is_nnx_checkpoint"). + """ + pure_model = state.model.to_pure_dict() if hasattr(state.model, "to_pure_dict") else dict(state.model) + bf16_model = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pure_model) + if checkpoint_manager is not None: + if checkpointing.save_checkpoint(checkpoint_manager, 0, bf16_model): + max_logging.log(f"saved an NNX decode checkpoint at {config.checkpoint_dir}") + checkpoint_manager.wait_until_finished() + + +def _possibly_unroll_lora_params_nnx(config, lora_state, lora_state_annotations, mesh): + """Unroll scanned LoRA delta layers when force_unroll is set on the NNX path. + + `lora_state` is a Linen-style `TrainState` (returned by `get_lora_abstract_state_nnx`) + whose `.params` is single-nested (`{"decoder": {...}}`, no outer `params` wrap) + and whose leaves at target attention paths are `lora_a.kernel`/`lora_b.kernel`. + """ + if not config.scan_layers or not config.force_unroll: + return + + decoder_params = lora_state.params["decoder"] + decoder_annotations = lora_state_annotations.params["decoder"] + + def unroll_layer_group(num_layers, layer_name="layers"): + layers = decoder_params.get(layer_name) + layers_annotations = decoder_annotations.get(layer_name) + if layers is None or layers_annotations is None: + return # No LoRA on this layer group; nothing to unroll. + + def new_pspec(x): + return jax.sharding.PartitionSpec(*(x[0 : config.param_scan_axis] + x[config.param_scan_axis + 1 :])) + + new_layer_annotation = jax.tree_util.tree_map(new_pspec, layers_annotations) + new_layer_sharding = jax.tree_util.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), new_layer_annotation) + + for i in range(num_layers): + + def slice_ith(input_layers): + return jax.tree_util.tree_map(lambda x: jnp.take(x, i, axis=config.param_scan_axis), input_layers) + + # pylint: disable=not-callable + new_layer = jax.jit(slice_ith, out_shardings=new_layer_sharding)(layers) + decoder_params[f"{layer_name}_{i}"] = new_layer + decoder_annotations[f"{layer_name}_{i}"] = new_layer_annotation + + del decoder_params[layer_name] + del decoder_annotations[layer_name] + jax.tree_util.tree_map(lambda x: x.delete() if hasattr(x, "delete") else None, layers) + + if config.decoder_block == DecoderBlockType.DEEPSEEK: + unroll_layer_group(config.first_num_dense_layers, layer_name="dense_layers") + unroll_layer_group(config.num_decoder_layers - config.first_num_dense_layers, layer_name="moe_layers") + else: + unroll_layer_group(config.num_decoder_layers, layer_name="layers") + + +def _save_lora_decode_checkpoint_nnx(config, lora_state, checkpoint_manager): + """Save a bf16 LoRA-only decode checkpoint (NNX path). + + `lora_state.params` is single-nested (NNX-derived shape). The on-disk + format mirrors the Linen LoRA decode shape so existing serving consumers + can keep reading it: a `TrainState` wrapper with `params` set to the + bf16-cast LoRA delta tree. The base model is loaded separately at serve + time via `apply_lora_on_base_params_nnx`. + """ + decode_state = maxtext_utils.init_decode_state( + None, jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), lora_state.params) + ) + if checkpoint_manager is not None: + if checkpointing.save_checkpoint(checkpoint_manager, 0, decode_state): + max_logging.log(f"saved a LoRA decode checkpoint at {config.checkpoint_dir}") + checkpoint_manager.wait_until_finished() + + +def _generate_lora_decode_checkpoints_nnx(config, mesh): + """NNX-shaped sibling of `_generate_lora_decode_checkpoints`. + + Builds the NNX abstract base model so `setup_initial_lora_state` + produces an NNX-derived `lora_state`, then runs an NNX-shape unroll/save. + """ + rng = random.PRNGKey(0) + rngs = maxtext_utils_nnx.create_nnx_rngs(config, rng_key=rng) + model = model_creation_utils.from_config(config, mesh=mesh, rngs=rngs) + _, tx = train_utils.create_training_optimizer(config, model) + + lora_adapters = gcs_utils.gcs_list_directories(config.lora_input_adapters_path) + for lora_id in lora_adapters: + lora_checkpoint_dir = os.path.join(config.checkpoint_dir, "loras", lora_id, "") + lora_adapter_path = os.path.join(config.lora_input_adapters_path, lora_id, "") + + checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( + lora_checkpoint_dir, + config.enable_checkpointing, + config.async_checkpointing, + config.checkpoint_period, + ) + + lora_config, lora_state, lora_state_annotations = lora_utils.setup_initial_lora_state( + model, None, tx, config, rng, mesh, checkpoint_manager, lora_adapter_path + ) + + _possibly_unroll_lora_params_nnx(config, lora_state, lora_state_annotations, mesh) + + gcs_utils.write_dict_to_gcs_json(lora_config, os.path.join(lora_checkpoint_dir, "adapter_config.json")) + + _save_lora_decode_checkpoint_nnx(config, lora_state, checkpoint_manager) + max_logging.log(f"Successfully saved LoRA checkpoint at: {os.path.join(lora_checkpoint_dir, '0', 'items')}") + + def generate_decode_checkpoint(config): """ Generate an decode checkpoint from a given training checkpoint. diff --git a/src/maxtext/utils/qk_clip_utils.py b/src/maxtext/utils/qk_clip_utils.py index 64848b8ffb..d3a7b926e4 100644 --- a/src/maxtext/utils/qk_clip_utils.py +++ b/src/maxtext/utils/qk_clip_utils.py @@ -16,6 +16,7 @@ import jax import jax.numpy as jnp +from flax import nnx def _get_key_name(k): @@ -30,132 +31,150 @@ def _get_key_name(k): def calculate_max_logit_metric(intermediate_outputs): """Extracts and computes the global maximum logit from intermediate outputs. - Args: - intermediate_outputs: A pytree containing model intermediates, potentially - including 'max_logits' sowed by Attention layers. + Recognizes two shapes: Linen sow stores `(array,)` so the leaf path ends in + `max_logits, 0`; NNX `nnx.Intermediate(array)` stores the array directly so + the leaf path ends in `max_logits`. - Returns: - The global maximum logit scalar, or None if no logits were found. + Returns the global max scalar, or None if no logits were found. """ all_max_logits = [] def extract_logits(path, val): - # 'sow' stores values in a tuple/list. tree_map descends into it. - # The path to the leaf array will look like: (..., 'max_logits', 0) - # So we check if the parent key (path[-2]) is 'max_logits'. - if len(path) >= 2: - parent_key = _get_key_name(path[-2]) - if parent_key == "max_logits": - all_max_logits.append(val) + if not path: + return + last_key = _get_key_name(path[-1]) + parent_key = _get_key_name(path[-2]) if len(path) >= 2 else None + if last_key == "max_logits" or parent_key == "max_logits": + all_max_logits.append(val) jax.tree_util.tree_map_with_path(extract_logits, intermediate_outputs) if not all_max_logits: return None - # Compute max per layer first to handle potential shape mismatches return jnp.max(jnp.stack([jnp.max(x) for x in all_max_logits])) -def apply_qk_clip(state, intermediate_outputs, config): - """Applies QK-Clip to MLA weights based on max_logits. - - Iterates over parameters. If a parameter belongs to an MLA attention layer, - it finds the corresponding max_logits statistics from intermediate_outputs, - calculates the clipping factor, and applies it to W_q and W_k components. - - Args: - state: The current training state containing model parameters. - intermediate_outputs: A dictionary of intermediate outputs from the model - forward pass. It is expected to contain 'max_logits' entries sowed by - Attention layers if QK-Clip is enabled. - config: The model configuration object, containing QK-Clip hyperparameters - (e.g. qk_clip_threshold, qk_nope_head_dim) and attention_type. - - Returns: - A new training state with updated (clipped) parameters. - - Raises: - ValueError: If the configured attention_type is not 'mla'. - """ +def _check_attention_type(config): if getattr(config, "attention_type", None) != "mla": raise ValueError( f"QK-Clip is only supported for MLA attention (attention_type='mla'). " f"Current configuration: {getattr(config, 'attention_type', 'None')}" ) - tau = float(config.qk_clip_threshold) - def clip_mla_weights(path, param): - """Applies QK-Clip to a single parameter if it's an MLA projection weight. +def _max_logits_at(curr): + """Read max_logits from a node in the intermediates tree. + + Returns the [batch, num_heads] array, or None if not present. Handles both + the Linen sow shape (`{"max_logits": (array,)}`) and the NNX shape + (`{"max_logits": array}` or `{"attention_op": {"max_logits": array}}`). + """ + if not isinstance(curr, dict): + return None + ml = curr.get("max_logits") + if ml is None and "attention_op" in curr and isinstance(curr["attention_op"], dict): + ml = curr["attention_op"].get("max_logits") + if ml is None: + return None + if isinstance(ml, (tuple, list)): + return ml[0] if ml else None + return ml - Args: - path: A tuple of JAX Key objects representing the hierarchy path to the parameter in the state PyTree. - param: The actual JAX array (weight tensor) at the given path. - Returns: - The scaled parameter if it is an MLA projection ('wq_b' or 'wkv_b'), otherwise the original parameter. - """ - # Skip irrelevant weights (embeddings, norms, etc.). - # We only care about specific MLA projection matrices ('wq_b', 'wkv_b'). +def _scale_from_max_logits(max_logits_batch, tau): + s_max = jnp.max(max_logits_batch, axis=0) + return jnp.minimum(1.0, tau / (s_max + 1e-6)) + + +def _clip_mla_weight(layer_name, param, scale, qk_nope): + """Apply the per-head scale to a wq_b or wkv_b kernel.""" + scale_b = scale[None, :, None] # broadcasts over [rank, heads, dim] + head = param[..., :qk_nope] + tail = param[..., qk_nope:] + head_new = head * jnp.sqrt(scale_b) + if layer_name == "wq_b": + tail_new = tail * scale_b + else: # wkv_b: tail is the V slice, untouched + tail_new = tail + return jnp.concatenate([head_new, tail_new], axis=-1) + + +def apply_qk_clip(state, intermediate_outputs, config): + """Applies QK-Clip to MLA weights based on max_logits (Linen path). + + Returns a new TrainState with `wq_b`/`wkv_b` kernels rescaled per-head. + """ + _check_attention_type(config) + tau = float(config.qk_clip_threshold) + + def clip_mla_weights(path, param): if len(path) < 2: return param - layer_name = _get_key_name(path[-2]) if layer_name not in ("wq_b", "wkv_b"): return param - # Search for max_logits in intermediate_outputs curr = intermediate_outputs.get("intermediates", intermediate_outputs) for node in path[:-2]: key = _get_key_name(node) if isinstance(curr, dict) and key in curr: curr = curr[key] else: - return param # Path not found in intermediates, skip + return param - if not isinstance(curr, dict) or "max_logits" not in curr: + max_logits_batch = _max_logits_at(curr) + if max_logits_batch is None: return param - # max_logits was sowed as a tuple (array,) - # shape: [batch, num_heads] - max_logits_sowed = curr["max_logits"] - if not max_logits_sowed: - return param + scale = _scale_from_max_logits(max_logits_batch, tau) + return _clip_mla_weight(layer_name, param, scale, config.qk_nope_head_dim) - max_logits_batch = max_logits_sowed[0] - - # Calculate S_max (per head) - # We want the global maximum across the batch dimension. - # Result shape: [num_heads] - s_max = jnp.max(max_logits_batch, axis=0) - - # Calculate scaling factor gamma - # gamma = tau / s_max. Clip if s_max > tau. - scale = jnp.minimum(1.0, tau / (s_max + 1e-6)) - - # Apply qk clipping based on weight type - if layer_name == "wq_b": - # MLA Up-projection for Query [rank, heads, q_head_dim] - qk_nope = config.qk_nope_head_dim - w_qc = param[..., :qk_nope] - w_qr = param[..., qk_nope:] - scale_b = scale[None, :, None] # Broadcast: [1, heads, 1] - w_qc_new = w_qc * jnp.sqrt(scale_b) - w_qr_new = w_qr * scale_b - return jnp.concatenate([w_qc_new, w_qr_new], axis=-1) - - elif layer_name == "wkv_b": - # MLA Up-projection for Key/Value [rank, heads, kv_head_dim] - qk_nope = config.qk_nope_head_dim - w_kc = param[..., :qk_nope] - w_v = param[..., qk_nope:] - scale_b = scale[None, :, None] - w_kc_new = w_kc * jnp.sqrt(scale_b) - return jnp.concatenate([w_kc_new, w_v], axis=-1) - - return param - - # Apply transformation new_params = jax.tree_util.tree_map_with_path(clip_mla_weights, state.params) return state.replace(params=new_params) + + +def apply_qk_clip_nnx(state, intermediate_outputs, config): + """Applies QK-Clip to MLA weights on an NNX TrainStateNNX. + + `state.model` is mutated in place (NNX modules are mutable). Returns `state` + so call sites can use the same `new_state = apply_qk_clip(...)` pattern as + the Linen path. + + The intermediates tree mirrors the NNX module hierarchy, so `max_logits` + sowed by `AttentionOp` lives at `...self_attention.attention_op.max_logits`. + We accept either that shape or `...self_attention.max_logits` (matching the + Linen-side fixtures and small-test setups). + """ + _check_attention_type(config) + tau = float(config.qk_clip_threshold) + + _, params_state, _ = nnx.split(state.model, nnx.Param, ...) + params_dict = params_state.to_pure_dict() + + def clip_mla_weights(path, param): + if len(path) < 2: + return param + layer_name = _get_key_name(path[-2]) + if layer_name not in ("wq_b", "wkv_b"): + return param + + curr = intermediate_outputs + for node in path[:-2]: + key = _get_key_name(node) + if isinstance(curr, dict) and key in curr: + curr = curr[key] + else: + return param + + max_logits_batch = _max_logits_at(curr) + if max_logits_batch is None: + return param + + scale = _scale_from_max_logits(max_logits_batch, tau) + return _clip_mla_weight(layer_name, param, scale, config.qk_nope_head_dim) + + new_params_dict = jax.tree_util.tree_map_with_path(clip_mla_weights, params_dict) + nnx.replace_by_pure_dict(params_state, new_params_dict) + nnx.update(state.model, params_state) + return state diff --git a/src/maxtext/utils/standalone_checkpointer.py b/src/maxtext/utils/standalone_checkpointer.py index 893fdc531a..6b1aa264c2 100644 --- a/src/maxtext/utils/standalone_checkpointer.py +++ b/src/maxtext/utils/standalone_checkpointer.py @@ -24,15 +24,19 @@ from typing import Sequence from absl import app +from flax import nnx from flax.linen import partitioning as nn_partitioning import jax from jax import numpy as jnp from maxtext.configs import pyconfig from maxtext.common import checkpointing +from maxtext.common import train_state_nnx from maxtext.models import models from maxtext.trainers.pre_train.train import get_first_step from maxtext.utils import max_logging from maxtext.utils import maxtext_utils +from maxtext.utils import maxtext_utils_nnx +from maxtext.utils import model_creation_utils from maxtext.utils import train_utils from maxtext.utils.model_creation_utils import from_config import numpy as np @@ -41,24 +45,30 @@ def checkpoint_loop(config, state=None): - """Main Checkpointing loop. + """Save/restore exerciser. - Saves checkpoints. - - Args: - config: - state: - ckpt_path: - - Returns: + Builds an abstract train state, restores or initializes it, perturbs the + optimizer moments via `add_entropy_to_checkpoint`, then writes checkpoints + on the configured cadence. Works on both Linen and NNX state shapes. """ - # Save/restore exerciser uses Linen-shaped optimizer state via - # add_entropy_to_checkpoint(). Route to Linen regardless of pure_nnx. - model = from_config(config) - mesh = model.mesh init_rng = jax.random.PRNGKey(config.init_weights_seed) - _, tx = train_utils.create_training_optimizer(config, model) - init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + if config.pure_nnx: + mesh = maxtext_utils.get_mesh_from_config(config) + rngs = maxtext_utils_nnx.create_nnx_rngs(config, rng_key=init_rng) + model = from_config(config, mesh=mesh, rngs=rngs) + _, tx = train_utils.create_training_optimizer(config, model) + _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(config, mesh) + + def init_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(nnx_model, optimizer) + + else: + model = from_config(config) + mesh = model.mesh + _, tx = train_utils.create_training_optimizer(config, model) + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, is_training=True) @@ -108,22 +118,38 @@ def checkpoint_loop(config, state=None): def add_entropy_to_checkpoint(state): - """Introduce randomness in checkpoints. - - This is useful to simulate real checkpoints, without training. - - Args: - state: Initial state - - Returns: - state: Returns state with entropy added to the optimizer state. + """Replace adam mu/nu with cos/sin of params. + + Stand-in for real training when exercising checkpoint save/restore. Handles + three shapes: + * Linen `TrainState`: `state.params` + `state.opt_state` (tuple). + * NNX `TrainStateNNX` (Module): `state.model` is an `nnx.Module`; the + optimizer's `opt_state` is the optax tuple of NamedTuples. + * NNX `nnx.State` (post-split, what `setup_training_state` returns under + `pure_nnx`): `state.model` and `state.optimizer.opt_state` are sub-States; + `opt_state[0].mu`/`nu` are themselves States that can be reassigned. """ + if hasattr(state, "model"): + if isinstance(state, nnx.Module): + params = nnx.state(state.model, nnx.Param) + else: + params = state.model.filter(nnx.Param) if hasattr(state.model, "filter") else state.model + new_mu = jax.tree_util.tree_map(lambda k: jnp.cos(1000 * k), params) + new_nu = jax.tree_util.tree_map(lambda k: jnp.sin(1000 * k), params) + + if isinstance(state, nnx.Module): + opt = state.optimizer + opt.opt_state = (opt.opt_state[0]._replace(mu=new_mu, nu=new_nu),) + tuple(opt.opt_state[1:]) + else: + state.optimizer.opt_state[0].mu = new_mu + state.optimizer.opt_state[0].nu = new_nu + return state + opt_0 = state.opt_state[0] opt_0 = opt_0._replace(mu=jax.tree_util.tree_map(lambda k: jnp.cos(1000 * k), state.params)) opt_0 = opt_0._replace(nu=jax.tree_util.tree_map(lambda k: jnp.sin(1000 * k), state.params)) new_opt = [opt_0] + list(state.opt_state[1:]) - state = state.replace(opt_state=new_opt) - return state + return state.replace(opt_state=new_opt) def main(argv: Sequence[str]) -> None: diff --git a/tests/unit/generate_param_only_checkpoint_nnx_test.py b/tests/unit/generate_param_only_checkpoint_nnx_test.py new file mode 100644 index 0000000000..57995a6145 --- /dev/null +++ b/tests/unit/generate_param_only_checkpoint_nnx_test.py @@ -0,0 +1,205 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX path of generate_param_only_checkpoint. + +Covers `_possibly_unroll_params_nnx` (slicing scanned NNX layers) and the +shape parity of `_save_decode_checkpoint_nnx`'s bf16 cast. +""" + +from types import SimpleNamespace +import unittest + +import jax +import jax.numpy as jnp +import numpy as np +import optax +from flax import nnx +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from flax.training import train_state as linen_train_state + +from maxtext.common.common_types import DecoderBlockType +from maxtext.common import train_state_nnx +from maxtext.utils.generate_param_only_checkpoint import ( + _possibly_unroll_lora_params_nnx, + _possibly_unroll_params_nnx, +) + + +class _ScanLayerLeaf(nnx.Module): + """One scanned-layer kernel with leading shape `[num_layers, *]`.""" + + def __init__(self, num_layers: int, in_dim: int, out_dim: int): + self.kernel = nnx.Param( + jnp.arange(num_layers * in_dim * out_dim, dtype=jnp.float32).reshape(num_layers, in_dim, out_dim) + ) + + +class _Decoder(nnx.Module): + + def __init__(self, num_layers: int): + self.layers = _ScanLayerLeaf(num_layers, 3, 5) + + +class _Model(nnx.Module): + + def __init__(self, num_layers: int): + self.decoder = _Decoder(num_layers) + + +def _make_split_state(num_layers: int): + model = _Model(num_layers) + optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + ts = train_state_nnx.TrainStateNNX(model, optimizer) + _, state = nnx.split(ts) + return state + + +def _make_shardings_state(state, mesh): + """Build a sibling shardings tree where each Variable is replaced by NamedSharding(replicated).""" + + def to_named(v): + return NamedSharding(mesh, PartitionSpec()) + + return jax.tree_util.tree_map(to_named, state, is_leaf=lambda x: isinstance(x, nnx.Variable)) + + +class PossiblyUnrollParamsNNXTest(unittest.TestCase): + + def setUp(self): + devices = np.array(jax.devices()).reshape(-1) + self.mesh = Mesh(devices, ("data",)) + + def test_unrolls_scanned_layers(self): + num_layers = 3 + state = _make_split_state(num_layers) + shardings = _make_shardings_state(state, self.mesh) + + original_kernel = np.asarray(state.model.decoder.layers.kernel[...]) + + config = SimpleNamespace( + scan_layers=True, + force_unroll=True, + pure_nnx=True, + param_scan_axis=0, + decoder_block=DecoderBlockType.LLAMA2, + num_decoder_layers=num_layers, + ) + + _possibly_unroll_params_nnx(config, state, shardings, self.mesh) + + self.assertNotIn("layers", state.model.decoder) + self.assertNotIn("layers", shardings.model.decoder) + for i in range(num_layers): + self.assertIn(f"layers_{i}", state.model.decoder) + self.assertIn(f"layers_{i}", shardings.model.decoder) + sliced = state.model.decoder[f"layers_{i}"]["kernel"][...] + expected = jnp.take(original_kernel, i, axis=0) + self.assertTrue(jnp.array_equal(sliced, expected)) + + def test_deepseek_split(self): + """DeepSeek decoder has separate dense/moe layer collections.""" + + # Build a DeepSeek-flavored synthetic model with two scanned groups. + class _DeepSeekDecoder(nnx.Module): + + def __init__(self): + self.dense_layers = _ScanLayerLeaf(2, 3, 5) + self.moe_layers = _ScanLayerLeaf(3, 3, 5) + + class _DSModel(nnx.Module): + + def __init__(self): + self.decoder = _DeepSeekDecoder() + + model = _DSModel() + optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param) + ts = train_state_nnx.TrainStateNNX(model, optimizer) + _, state = nnx.split(ts) + shardings = _make_shardings_state(state, self.mesh) + + config = SimpleNamespace( + scan_layers=True, + force_unroll=True, + pure_nnx=True, + param_scan_axis=0, + decoder_block=DecoderBlockType.DEEPSEEK, + num_decoder_layers=5, + first_num_dense_layers=2, + ) + + _possibly_unroll_params_nnx(config, state, shardings, self.mesh) + + self.assertNotIn("dense_layers", state.model.decoder) + self.assertNotIn("moe_layers", state.model.decoder) + for i in range(2): + self.assertIn(f"dense_layers_{i}", state.model.decoder) + for i in range(3): + self.assertIn(f"moe_layers_{i}", state.model.decoder) + + +class PossiblyUnrollLoraParamsNNXTest(unittest.TestCase): + """The LoRA delta tree is single-nested (`{"decoder": {...}}`) and held in a + Linen `TrainState` even on the NNX path — the unroll has to walk that shape.""" + + def setUp(self): + devices = np.array(jax.devices()).reshape(-1) + self.mesh = Mesh(devices, ("data",)) + + def _make_lora_state(self, num_layers: int, lora_rank: int = 4): + """Build a synthetic LoRA delta TrainState mirroring `get_lora_abstract_state_nnx`'s output shape.""" + lora_a = jnp.arange(num_layers * 8 * lora_rank, dtype=jnp.float32).reshape(num_layers, 8, lora_rank) + lora_b = jnp.arange(num_layers * lora_rank * 4 * 2, dtype=jnp.float32).reshape(num_layers, lora_rank, 4, 2) + params = { + "decoder": { + "layers": { + "self_attention": { + "query": {"lora_a.kernel": lora_a, "lora_b.kernel": lora_b}, + } + } + } + } + annotations_params = jax.tree_util.tree_map(lambda _: PartitionSpec(), params) + state = linen_train_state.TrainState(step=0, apply_fn=None, params=params, tx=None, opt_state={}) + annotations = linen_train_state.TrainState(step=0, apply_fn=None, params=annotations_params, tx=None, opt_state={}) + return state, annotations + + def test_unrolls_scanned_lora_layers(self): + num_layers = 3 + state, annotations = self._make_lora_state(num_layers) + original_a = np.asarray(state.params["decoder"]["layers"]["self_attention"]["query"]["lora_a.kernel"]) + + config = SimpleNamespace( + scan_layers=True, + force_unroll=True, + pure_nnx=True, + param_scan_axis=0, + decoder_block=DecoderBlockType.LLAMA2, + num_decoder_layers=num_layers, + ) + + _possibly_unroll_lora_params_nnx(config, state, annotations, self.mesh) + + self.assertNotIn("layers", state.params["decoder"]) + self.assertNotIn("layers", annotations.params["decoder"]) + for i in range(num_layers): + self.assertIn(f"layers_{i}", state.params["decoder"]) + sliced_a = state.params["decoder"][f"layers_{i}"]["self_attention"]["query"]["lora_a.kernel"] + expected = jnp.take(original_a, i, axis=0) + self.assertTrue(jnp.array_equal(sliced_a, expected)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/qk_clip_test.py b/tests/unit/qk_clip_test.py index 8cf4b25ea2..3ceda39114 100644 --- a/tests/unit/qk_clip_test.py +++ b/tests/unit/qk_clip_test.py @@ -27,7 +27,7 @@ from maxtext.common.gcloud_stub import is_decoupled from maxtext.layers import attention_mla from maxtext.utils import maxtext_utils -from maxtext.utils.qk_clip_utils import apply_qk_clip, calculate_max_logit_metric +from maxtext.utils.qk_clip_utils import apply_qk_clip, apply_qk_clip_nnx, calculate_max_logit_metric from maxtext.configs import pyconfig from tests.utils.test_helpers import get_test_config_path @@ -500,5 +500,179 @@ def replace_fn(params=None, **kwargs): ) +class _MockAttentionOp(nnx.Module): + """Holds the sowed `max_logits` intermediate at the same tree depth as production.""" + + def __init__(self, max_logits=None): + if max_logits is not None: + self.max_logits = nnx.Intermediate(max_logits) + + +class _MockMLAAttention(nnx.Module): + """`wq_b.kernel` + `wkv_b.kernel` as `nnx.Param`, plus an `attention_op` child.""" + + def __init__(self, wq_b_kernel, wkv_b_kernel, max_logits=None): + self.wq_b = nnx.Module() + self.wq_b.kernel = nnx.Param(wq_b_kernel) + self.wkv_b = nnx.Module() + self.wkv_b.kernel = nnx.Param(wkv_b_kernel) + self.attention_op = _MockAttentionOp(max_logits) + + +class _MockLayer(nnx.Module): + + def __init__(self, attn): + self.self_attention = attn + + +class _MockDecoder(nnx.Module): + + def __init__(self, layer): + self.layers_0 = layer + + +class _MockTransformer(nnx.Module): + + def __init__(self, decoder): + self.decoder = decoder + + +class _MockState: + """Stand-in for `TrainStateNNX`: only `apply_qk_clip_nnx` accesses `.model`.""" + + def __init__(self, model): + self.model = model + + +def _build_mock_nnx_state(wq_b, wkv_b, max_logits=None): + attn = _MockMLAAttention(wq_b, wkv_b, max_logits) + return _MockState(_MockTransformer(_MockDecoder(_MockLayer(attn)))) + + +def _read_kernels(state): + attn = state.model.decoder.layers_0.self_attention + return attn.wq_b.kernel.value, attn.wkv_b.kernel.value + + +class QKClipNNXTest(unittest.TestCase): + """Mirrors `QKClipTest` against the NNX path.""" + + def _make_config(self, threshold, nope_dim, attention_type="mla"): + Config = namedtuple("Config", ["qk_clip_threshold", "qk_nope_head_dim", "attention_type"]) + return Config(qk_clip_threshold=threshold, qk_nope_head_dim=nope_dim, attention_type=attention_type) + + def test_raises_error_for_non_mla(self): + state = _build_mock_nnx_state(jnp.zeros((1, 1, 2)), jnp.zeros((1, 1, 2))) + config = self._make_config(threshold=10.0, nope_dim=4, attention_type="dot_product") + with self.assertRaisesRegex(ValueError, "QK-Clip is only supported for MLA attention"): + apply_qk_clip_nnx(state, {}, config) + + def test_apply_qk_clip_logic(self): + rng = jax.random.PRNGKey(0) + rng_q, rng_kv = jax.random.split(rng) + wq_b = jax.random.normal(rng_q, (2, 2, 6)) + wkv_b = jax.random.normal(rng_kv, (2, 2, 6)) + state = _build_mock_nnx_state(wq_b, wkv_b) + config = self._make_config(threshold=10.0, nope_dim=4) + + # Head 0 logit 20.0 (>tau, scale=0.5); head 1 logit 5.0 (