diff --git a/src/maxtext/experimental/rl/grpo_trainer.py b/src/maxtext/experimental/rl/grpo_trainer.py index cb7889f7b9..282cb31a54 100644 --- a/src/maxtext/experimental/rl/grpo_trainer.py +++ b/src/maxtext/experimental/rl/grpo_trainer.py @@ -59,9 +59,12 @@ from jax import random from flax.linen import partitioning as nn_partitioning +from flax import nnx from flax import struct from flax.nnx import TrainState +from maxtext.common import train_state_nnx + import transformers from ml_goodput_measurement.src.goodput import GoodputRecorder @@ -85,11 +88,12 @@ from maxtext.experimental.rl import grpo_utils from maxtext.common.metric_logger import MetricLogger from maxtext.common.vertex_tensorboard import VertexTensorboardManager -from maxtext.inference import offline_engine from maxtext.utils import exceptions from maxtext.utils import gcs_utils from maxtext.utils import max_logging from maxtext.utils import max_utils +from maxtext.utils import maxtext_utils_nnx +from maxtext.utils import model_creation_utils from maxtext.utils import maxtext_utils from maxtext.utils import sharding from maxtext.utils import train_utils @@ -335,34 +339,244 @@ def grpo_loss_fn(model, config, data, dropout_rng, params, reference_params, is_ return loss, aux +def grpo_loss_fn_nnx(policy_model, config, data, dropout_rng, params, reference_model, is_train=True): + """GRPO loss function for the NNX path. + + See `grpo_loss_fn` above for the algorithm (per-token policy ratio with + clipping, group-relative advantage normalization, optional KL to a frozen + reference). The signature mirrors the Linen `grpo_loss_fn` so callers can + dispatch on the same call shape. The reference forward is wrapped in + `stop_gradient`, so gradients only flow into the policy. + + Args: + policy_model: The policy `nnx.Module`. Both parameters and RNG state are + carried on the module itself. + config: Training configuration object. + data: A batch dict produced by the GRPO input pipeline. + dropout_rng: Unused on the NNX path; kept for signature parity with the + Linen `grpo_loss_fn`. + params: Unused on the NNX path; kept for signature parity. + reference_model: Frozen reference `nnx.Module` used to compute the KL + term. Not updated by the optimizer. + is_train: Whether to run the forward in training mode (dropout enabled). + + Returns: + A tuple `(loss, aux)` where `loss` is a scalar and `aux` is a `LossAux` + dataclass with logging metrics. + """ + del dropout_rng, params # The policy `nnx.Module` carries these. + + prompt_with_completions = data[f"{config.train_data_columns}_completions"] + prompt_completions_position = data[f"{config.train_data_columns}_completions_position"] + prompt_completions_segmentation = data[f"{config.train_data_columns}_completions_segmentation"] + completions_segmentation = data["ar_completions_segmentation"] + + token_logps_policy, intermediate_outputs = grpo_utils.compute_log_probs_nnx( + policy_model, + prompt_with_completions, + prompt_completions_position, + prompt_completions_segmentation, + completions_segmentation, + config, + is_train=is_train, + ) + + completion_target_segmentation = data["ar_completions_segmentation"][..., 1:] + valid_seq_mask = completion_target_segmentation != 0 + + rewards = grpo_utils.dummy_reward_len(valid_seq_mask) + rewards = jnp.array(rewards) + + G = config.num_generations + rewards_grouped = rewards.reshape(-1, G) + group_mean = jnp.mean(rewards_grouped, axis=1) + group_std = jnp.std(rewards_grouped, axis=1) + repeated_group_mean = jnp.repeat(group_mean, G) + repeated_group_std = jnp.repeat(group_std, G) + advantages = (rewards - repeated_group_mean) / (repeated_group_std + EPS) + advantages_exp = advantages[:, None] + + if data["completions_logprobs"] is None: # off-policy + old_per_token_logps = jax.lax.stop_gradient(token_logps_policy) + else: # on-policy + old_per_token_logps = data["completions_logprobs"] + + policy_diff = token_logps_policy - old_per_token_logps + coef_1 = jnp.exp(policy_diff) + coef_2 = jnp.clip(coef_1, 1 - config.grpo_epsilon, 1 + config.grpo_epsilon) + loss_tokens = -jnp.minimum(coef_1 * advantages_exp, coef_2 * advantages_exp) + + if config.grpo_beta != 0.0: + token_logps_ref, _ = grpo_utils.compute_log_probs_nnx( + reference_model, + prompt_with_completions, + prompt_completions_position, + prompt_completions_segmentation, + completions_segmentation, + config, + is_train=False, + ) + token_logps_ref = jax.lax.stop_gradient(token_logps_ref) + token_diff_logps_ref_policy = token_logps_ref - token_logps_policy + per_token_kl = jnp.exp(token_diff_logps_ref_policy) - token_diff_logps_ref_policy - 1 + per_token_kl = per_token_kl * valid_seq_mask + loss_tokens += config.grpo_beta * per_token_kl + + loss_per_example = jnp.sum(loss_tokens * valid_seq_mask, axis=1) / jnp.clip(jnp.sum(valid_seq_mask, axis=1), min=1) + loss = jnp.mean(loss_per_example) + total_weights = jnp.sum(valid_seq_mask) + + moe_lb_loss = 0.0 + if config.num_experts > 1: + moe_lb_losses = maxtext_utils.collect_intermediates_by_suffix(intermediate_outputs, "moe_lb_loss") + if moe_lb_losses: + moe_lb_loss = jnp.mean(jnp.concatenate(moe_lb_losses)) + loss += moe_lb_loss + + if config.grpo_beta != 0.0: + avg_kl = jnp.mean((per_token_kl * valid_seq_mask) / jnp.clip(jnp.sum(valid_seq_mask, axis=1, keepdims=True), min=1)) + else: + avg_kl = None + avg_completion_length = jnp.mean(jnp.sum(data["ar_completions_segmentation"] != 0, axis=1)) + aux = LossAux( + total_loss=loss, + avg_reward=jnp.mean(rewards), + avg_reward_std=jnp.mean(repeated_group_std), + avg_advantage=jnp.mean(advantages), + avg_kl=avg_kl, + completion_length=avg_completion_length, + moe_lb_loss=moe_lb_loss, + total_weights=total_weights, + ) + return loss, aux + + # ----------------------------------------------------------------------------- # Trainer and top level training functions # ----------------------------------------------------------------------------- +def _train_step_nnx(model_graphdef, config, state_mesh_shardings, state, data): + """Run one GRPO training step on the NNX path. + + Reconstructs `TrainStateNNX` from `(model_graphdef, state)`, computes the + GRPO loss and gradients over the policy parameters, applies the gradient + update, and returns the new state with `nnx.Intermediate` values filtered + out (sown forward-pass artifacts must not persist across steps). + + Args: + model_graphdef: NNX `GraphDef` of the `TrainStateNNX`. + config: Training configuration object. + state_mesh_shardings: Sharding spec for the train state. Unused on this + path; kept for signature parity with `train_step`. + state: Flat `nnx.State` matching `model_graphdef`. + data: A batch dict produced by the GRPO input pipeline. + + Returns: + A tuple `(new_state, metrics)`. `new_state` is filtered to exclude + `nnx.Intermediate`. `metrics` is a dict shaped like the Linen path's. + """ + del state_mesh_shardings # Host-offload paths are not yet wired up here. + + if config.gradient_accumulation_steps > 1: + raise NotImplementedError( + "GRPO + pure_nnx + gradient_accumulation_steps>1 not supported yet. " + "Set gradient_accumulation_steps=1 or pure_nnx=False." + ) + + state = nnx.merge(model_graphdef, state) # Reconstruct the TrainStateNNX. + policy_graphdef, curr_params, rest = nnx.split(state.model, nnx.Param, ...) + + def diff_wrapper(param, rest, config, data): + local_model = nnx.merge(policy_graphdef, param, rest, copy=True) + loss, aux = grpo_loss_fn_nnx(local_model, config, data, None, None, state.reference_model, is_train=True) + _, _, new_rest = nnx.split(local_model, nnx.Param, ...) + return loss, (aux, new_rest) + + grad_func = jax.value_and_grad(diff_wrapper, argnums=0, has_aux=True) + (loss, (aux, new_rest)), raw_grads = grad_func(curr_params, rest, config, data) + nnx.update(state.model, new_rest) + + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, None, config.gradient_clipping_threshold) + else: + grads = raw_grads + state.apply_gradients(grads) + new_state = state + + scalar_metrics = { + "learning/loss": loss, + "learning/avg_reward": aux.avg_reward, + "learning/avg_reward_std": aux.avg_reward_std, + "learning/avg_advantage": aux.avg_advantage, + "learning/avg_kl": aux.avg_kl, + "learning/completion_length": aux.completion_length, + "learning/moe_lb_loss": aux.moe_lb_loss, + "learning/total_weights": aux.total_weights, + "learning/grad_norm": max_utils.l2norm_pytree(grads), + "learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads), + } + new_policy_params = nnx.state(new_state.model, nnx.Param) + scalar_metrics["learning/param_norm"] = max_utils.l2norm_pytree(new_policy_params) + metrics = {"scalar": scalar_metrics, "scalars": {}} + + return nnx.state(new_state, nnx.Not(nnx.Intermediate)), metrics + + +def _eval_step_nnx(model_graphdef, config, state, data): + """Run one GRPO evaluation step on the NNX path. + + Reconstructs `TrainStateNNX` and computes the GRPO loss without updating + any parameters. + + Args: + model_graphdef: NNX `GraphDef` of the `TrainStateNNX`. + config: Training configuration object. + state: Flat `nnx.State` matching `model_graphdef`. + data: A batch dict produced by the GRPO input pipeline. + + Returns: + A metrics dict shaped like the Linen `eval_step`'s. + """ + state = nnx.merge(model_graphdef, state) + loss, aux = grpo_loss_fn_nnx(state.model, config, data, None, None, state.reference_model, is_train=False) + metrics = { + "scalar": { + "evaluation/loss": loss, + "evaluation/total_loss": aux.total_loss, + "evaluation/total_weights": aux.total_weights, + "evaluation/moe_lb_loss": aux.moe_lb_loss, + }, + } + return metrics + + def train_step(model, config, state_mesh_shardings, params_shardings, state, data, dropout_rng): - """Performs a single training step of the GRPO algorithm. + """Run one GRPO training step. - This function computes the GRPO loss, calculates gradients, and updates the - model's parameters. It handles gradient accumulation and clipping as configured. - The reference model's parameters are held constant during the update. + Computes the GRPO loss and gradients and applies the update to the policy + parameters; the reference parameters are held constant. The Linen and NNX + paths share this entry point: on the NNX path `model` is an NNX + `GraphDef` and `state` is the matching flat `nnx.State` of a + `TrainStateNNX`. On the Linen path they are the usual `nn.Module` and + `TrainState`. Args: - model: The transformer model to be trained. - config: The training configuration object. - state_mesh_shardings: Pytree of sharding specifications for the training state. - params_shardings: Pytree of sharding specifications for the model parameters. - This argument is not used and is kept to match the signature of other trainers. - state: The current training state, including parameters and optimizer state. - data: A batch of training data, including prompts and generated completions. - dropout_rng: JAX PRNG key for dropout. + model: Linen `nn.Module` or NNX `GraphDef`, depending on `config.pure_nnx`. + config: Training configuration object. + state_mesh_shardings: Pytree of shardings matching `state`. + params_shardings: Param-only shardings, used for gradient accumulation + on the Linen path. Ignored on NNX. + state: Linen `TrainState` or NNX `nnx.State` matching `model`. + data: A batch dict produced by the GRPO input pipeline. + dropout_rng: PRNG key for dropout (Linen only). Returns: - A tuple containing: - - new_state: The updated training state after applying gradients. - - metrics: A dictionary of metrics for logging, including loss, reward, - and gradient norms. + A tuple `(new_state, metrics)`. """ + if config.pure_nnx: + return _train_step_nnx(model, config, state_mesh_shardings, state, data) + state, reference_params = _split_grpo_state(state) state_mesh_shardings, reference_params_sharding = _split_grpo_state(state_mesh_shardings) extra_grpo_args = [reference_params] @@ -473,6 +687,8 @@ def eval_step(model, config, state, data, dropout_rng): Returns: A dictionary of evaluation metrics. """ + if config.pure_nnx: + return _eval_step_nnx(model, config, state, data) reference_params, extra_grpo_args, _loss_fn = [], [], grpo_loss_fn state, reference_params = _split_grpo_state(state) @@ -542,26 +758,51 @@ def setup_train_loop( - eval_data_iterator: The iterator for the evaluation dataset (or None). - state: The initialized training state. """ - # GRPO is Linen-shaped end-to-end (inference goes through Linen MaxEngine). - # Route to Linen regardless of pure_nnx; warn since NNX checkpoints won't load. - if config.pure_nnx or config_inference.pure_nnx: - max_logging.log( - "WARNING: GRPO RL trainer does not yet support pure_nnx natively; " - "running on the Linen path. NNX-format checkpoints will not load correctly here." + if config.pure_nnx != config_inference.pure_nnx: + raise ValueError( + f"config.pure_nnx ({config.pure_nnx}) and config_inference.pure_nnx " f"({config_inference.pure_nnx}) must agree." ) with maybe_record_goodput(recorder, GoodputEvent.TPU_INIT): max_logging.log("Training mesh used for the workload") num_inference_devices = config.inference_devices_per_replica * config.inference_replicas training_devices = jax.devices()[num_inference_devices:] - model = mt.from_config(config, devices=training_devices) + init_rng = jax.random.PRNGKey(config.init_weights_seed) + + if config.pure_nnx: + training_mesh = maxtext_utils.get_mesh_from_config(config, devices=training_devices) + training_rngs = maxtext_utils_nnx.create_nnx_rngs(config, rng_key=init_rng) + model = mt.from_config(config, devices=training_devices, mesh=training_mesh, rngs=training_rngs) + else: + model = mt.from_config(config, devices=training_devices) mesh = model.mesh + max_logging.log("Inference mesh used for the workload") inference_devices = jax.devices()[:num_inference_devices] - inference_model = mt.from_config(config_inference, devices=inference_devices) + if config_inference.pure_nnx: + inference_mesh_obj = maxtext_utils.get_mesh_from_config(config_inference, devices=inference_devices) + inference_rngs = maxtext_utils_nnx.create_nnx_rngs(config_inference, rng_key=init_rng) + inference_model = mt.from_config( + config_inference, devices=inference_devices, mesh=inference_mesh_obj, rngs=inference_rngs + ) + else: + inference_model = mt.from_config(config_inference, devices=inference_devices) inference_mesh = inference_model.mesh - init_rng = jax.random.PRNGKey(config.init_weights_seed) + learning_rate_schedule, tx = train_utils.create_training_optimizer(config, model) - init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) + + if config.pure_nnx: + _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(config, mesh, devices=training_devices) + + def init_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + # Reference uses the same init seed so it starts identical to the policy. + reference_model = _create_model_partial() + # pylint: disable-next=unexpected-keyword-arg + return train_state_nnx.TrainStateNNX(nnx_model, optimizer, reference_model=reference_model) + + else: + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, init_rng) checkpoint_manager = train_utils.create_checkpoint_manager(config, mesh, init_state_fn) with maybe_record_goodput(recorder, GoodputEvent.TRAINING_PREPARATION): @@ -570,16 +811,29 @@ def setup_train_loop( data_iterator, config, mesh, checkpoint_manager, init_state_fn ) - # create inference_state_mesh_shardings from inference_mesh (Linen path; see warning above) - init_inference_state_fn = functools.partial( - maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng - ) + if config_inference.pure_nnx: + _create_inference_partial, _ = model_creation_utils.create_nnx_abstract_model( + config_inference, inference_mesh, devices=inference_devices + ) + + def init_inference_state_fn(): + inference_nnx_model = _create_inference_partial() + return train_state_nnx.TrainStateNNX(inference_nnx_model, None) + + else: + init_inference_state_fn = functools.partial( + maxtext_utils.init_initial_state, inference_model, tx, config_inference, False, init_rng + ) inference_state_mesh_shardings = maxtext_utils.get_abstract_state( config_inference, inference_mesh, init_inference_state_fn, is_training=False )[2] if not config.using_pipeline_parallelism: # The vocab tensor(s) of shape [vocab, embed] (and transpose) are not sharded by stage - sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance) + if config.pure_nnx: + params_for_check = nnx.state(state.model, nnx.Param) + sharding.assert_params_sufficiently_sharded(params_for_check, mesh, config.sharding_tolerance) + else: + sharding.assert_params_sufficiently_sharded(state.params, mesh, config.sharding_tolerance) return ( init_rng, @@ -692,10 +946,16 @@ def train_loop(config, config_inference, recorder, state=None): token=config.hf_access_token, ) - if "reference_params" not in state.params: - reference_params = jax.tree.map(jnp.copy, state.params["params"]) - state = _merge_grpo_state(state, reference_params) - state_mesh_shardings = _merge_grpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) + if config.pure_nnx: + # `reference_model` is a sibling field on TrainStateNNX, populated by + # init_state_fn. Nothing to merge here; just verify it is present. + if not hasattr(state, "reference_model"): + raise RuntimeError("NNX GRPO state is missing reference_model; check setup_train_loop.") + else: + if "reference_params" not in state.params: + reference_params = jax.tree.map(jnp.copy, state.params["params"]) + state = _merge_grpo_state(state, reference_params) + state_mesh_shardings = _merge_grpo_state(state_mesh_shardings, state_mesh_shardings.params["params"]) p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( config, model, mesh, state, state_mesh_shardings, train_step, eval_step, eval_data_iterator @@ -703,6 +963,9 @@ def train_loop(config, config_inference, recorder, state=None): data_sharding = sharding.get_input_data_sharding(config, mesh) + # Lazy import: pulls in maxengine and jetstream stubs. + from maxtext.inference import offline_engine # pylint: disable=import-outside-toplevel + inference_engine = offline_engine.OfflineEngine( config=config_inference, mesh=inference_mesh, @@ -717,7 +980,11 @@ def train_loop(config, config_inference, recorder, state=None): metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule) # Write train config params, num model params, and XLA flags to tensorboard - metric_logger.write_setup_info_to_tensorboard(state.params["params"]) + if config.pure_nnx: + params_for_metrics = nnx.state(state.model, nnx.Param) + metric_logger.write_setup_info_to_tensorboard(params_for_metrics) + else: + metric_logger.write_setup_info_to_tensorboard(state.params["params"]) def generation_worker_fn( worker_inference_engine, @@ -841,20 +1108,32 @@ def generation_worker_fn( state, metrics = p_train_step(state, example_batch, train_rng) with jax.profiler.StepTraceAnnotation("transfer data", step_num=step): if step != 0 and step % config.inference_rollouts == 0: - grpo_utils.pathways_reshard( - config_inference, - inference_engine, - {"params": state.params["params"]}, - {"params": state_mesh_shardings.params["params"]}, - mesh, - {"params": inference_state_mesh_shardings.params["params"]}, - ) + if config.pure_nnx: + grpo_utils.pathways_reshard_nnx( + config_inference, + inference_engine, + state.model, + state_mesh_shardings.model, + inference_state_mesh_shardings.model, + ) + else: + grpo_utils.pathways_reshard( + config_inference, + inference_engine, + {"params": state.params["params"]}, + {"params": state_mesh_shardings.params["params"]}, + mesh, + {"params": inference_state_mesh_shardings.params["params"]}, + ) with data_buffer_lock: data_buffer.clear() step_time_delta = datetime.datetime.now() - last_step_completion - state_to_save = _split_grpo_state(state)[0] + # On the Linen path, the reference is embedded in `state.params` and is + # stripped before saving. On the NNX path, the reference is a sibling + # field on TrainStateNNX, so the whole state can be saved as-is. + state_to_save = state if config.pure_nnx else _split_grpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator, step) if config.dump_hlo and step == start_step: @@ -898,7 +1177,7 @@ def generation_worker_fn( metric_logger.buffer_and_write_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: - state_to_save = _split_grpo_state(state)[0] + state_to_save = state if config.pure_nnx else _split_grpo_state(state)[0] checkpointing.maybe_save_checkpoint(checkpoint_manager, state_to_save, config, data_iterator) elif checkpoint_manager is not None: # in case the last checkpoint_period checkpoint is still in progress diff --git a/src/maxtext/experimental/rl/grpo_utils.py b/src/maxtext/experimental/rl/grpo_utils.py index 352a2b3b8d..8989405eab 100644 --- a/src/maxtext/experimental/rl/grpo_utils.py +++ b/src/maxtext/experimental/rl/grpo_utils.py @@ -21,8 +21,9 @@ import jaxtyping from typing import Any, Callable +from flax import nnx + from maxtext.common.common_types import DecoderBlockType -from maxtext.inference.offline_engine import InputData from maxtext.utils import max_logging from maxtext.utils import max_utils @@ -112,6 +113,63 @@ def compute_log_probs( return token_log_probs, intermediate_outputs +def compute_log_probs_nnx( + model, + inputs, + inputs_position, + inputs_segmentation, + completion_segmentation, + config, + is_train=False, +): + """Compute per-token log-probabilities for an NNX policy. + + Mirrors `compute_log_probs` but takes an `nnx.Module` directly: the model + carries its own parameters and RNG state, so there is no `params` or + `rngs` argument. Intermediate values sown by the forward pass are read + back via `nnx.state(model, nnx.Intermediate)`. + + Args: + model: Policy `nnx.Module`. + inputs: A `[B, L]` array of input token ids. + inputs_position: A `[B, L]` array of token positions. + inputs_segmentation: A `[B, L]` array of segment ids. + completion_segmentation: A `[B, L]` array that masks the completion + portion of the sequence. + config: Training configuration object. + is_train: Whether to run the forward in training mode. + + Returns: + A tuple `(token_log_probs, intermediate_outputs)` where + `token_log_probs` has shape `[B, L-1]`. + """ + logits = model( + decoder_input_tokens=inputs, + decoder_positions=inputs_position, + decoder_segment_ids=inputs_segmentation, + enable_dropout=(config.enable_dropout if is_train else False), + ) + intermediate_outputs = nnx.state(model, nnx.Intermediate).to_pure_dict() + logits = logits / config.decode_sampling_temperature + + targets = inputs[:, 1:] + shifted_completion_segmentation = jax.lax.dynamic_slice( + completion_segmentation, (0, 1), (completion_segmentation.shape[0], completion_segmentation.shape[1] - 1) + ) + shifted_completion_segmentation = jnp.pad( + shifted_completion_segmentation, ((0, 0), (0, 1)), mode="constant", constant_values=0 + ) + mask = shifted_completion_segmentation[..., None] + mask = jnp.broadcast_to(mask, logits.shape) + masked_logits = jnp.where(mask, logits, -jnp.inf) + log_probs = jax.nn.log_softmax(masked_logits, axis=-1) + log_probs = jnp.where(mask, log_probs, -0.0) + log_probs = log_probs[:, :-1, :] + token_log_probs = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0] + token_log_probs = token_log_probs * shifted_completion_segmentation[:, :-1] + return token_log_probs, intermediate_outputs + + def generate_offline_completions(config, tokenizer_model, inference_engine, data): """Generates completions for a batch of prompts using an offline engine. @@ -125,6 +183,10 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data The input `data` dictionary updated with the generated completions, segmentations, positions, and log-probabilities. """ + # Lazy import: pulls in maxengine and jetstream stubs, which we only want to + # touch when this function is actually called (i.e. during a real GRPO run). + from maxtext.inference.offline_engine import InputData # pylint: disable=import-outside-toplevel + data[config.train_data_columns] = np.asarray( jnp.repeat(data[config.train_data_columns], config.num_generations, axis=0) ) @@ -175,6 +237,40 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data return data +def pathways_reshard_nnx( + config, inference_engine, policy_state_model, source_shardings_model, destination_shardings_model +): + """Reshard NNX policy params onto the inference mesh. + + Splits the policy `nnx.Param` state out of the training-side TrainStateNNX + model substate, reshards it onto the inference mesh, and pushes the + resharded params into the inference engine. Requires `scan_layers=True`; + the Linen `unscan_train_state_params` helper has no NNX equivalent yet. + + Args: + config: Training configuration object. + inference_engine: Inference engine to receive the resharded params. + policy_state_model: Training-side `state.model` substate. + source_shardings_model: Shardings for `policy_state_model`. Unused + because the same shardings are already attached to the params. + destination_shardings_model: Shardings for the inference-side model. + """ + if not config.scan_layers: + raise NotImplementedError( + "GRPO + pure_nnx + scan_layers=False not supported yet. " "Use scan_layers=True or pure_nnx=False." + ) + policy_params = nnx.state(policy_state_model, nnx.Param) + source_param_shardings = nnx.state(source_shardings_model, nnx.Param) + dest_param_shardings = nnx.state(destination_shardings_model, nnx.Param) + del source_param_shardings # Already encoded on policy_params. + with ( + jax.transfer_guard_device_to_host("disallow_explicit"), + jax.transfer_guard_host_to_device("disallow_explicit"), + ): + resharded_params = reshard_pytree(policy_params, dest_param_shardings) + inference_engine.update_params(resharded_params) + + def pathways_reshard(config, inference_engine, params, source_shardings, source_mesh, destination_shardings): """Reshards model parameters from training to inference sharding. diff --git a/src/maxtext/inference/maxengine/maxengine.py b/src/maxtext/inference/maxengine/maxengine.py index 778bc04285..ec6b934e5e 100644 --- a/src/maxtext/inference/maxengine/maxengine.py +++ b/src/maxtext/inference/maxengine/maxengine.py @@ -447,12 +447,12 @@ def _load_params_nnx(self, params, rng): return params_state def load_single_adapter(self, adapter_path): + """Load a single LoRA adapter from `adapter_path`. + + Expects `adapter_config.json` at the root and adapter weights under + `/0/items`. The returned `params` follows the same tree + shape as `self.abstract_params` (NNX or Linen, depending on the engine). """ - Load Single adapter from adapter_path. - Expect adapter_config.json and LoRA adapter weights at this path within subdirectory `/0/items`. - """ - if self.config.pure_nnx: - raise NotImplementedError("pure_nnx + LoRA not yet supported. Use pure_nnx=False.") adapter_config_path = os.path.join(adapter_path, "adapter_config.json") adapter_weights_path = os.path.join(adapter_path, "0", "items") @@ -475,14 +475,20 @@ def apply_adapter(self, base_params, adapter_config, adapter_params): lora_rank = int(adapter_config["r"]) lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank - lora_utils.apply_lora_on_base_params(base_params, adapter_params, lora_scale_factor) + if self.config.pure_nnx: + lora_utils.apply_lora_on_base_params_nnx(base_params, adapter_params, lora_scale_factor) + else: + lora_utils.apply_lora_on_base_params(base_params, adapter_params, lora_scale_factor) def unapply_adapter(self, base_params, adapter_config, adapter_params): """Unapply the adapter params from the merged params to get back the base params.""" lora_rank = int(adapter_config["r"]) lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank - lora_utils.unapply_lora_from_base_params(base_params, adapter_params, lora_scale_factor) + if self.config.pure_nnx: + lora_utils.unapply_lora_from_base_params_nnx(base_params, adapter_params, lora_scale_factor) + else: + lora_utils.unapply_lora_from_base_params(base_params, adapter_params, lora_scale_factor) def quantize_params(self, state, rng: PRNGKeyType | None = None): """Forward pass to quantize decode params.""" diff --git a/src/maxtext/utils/lora_utils.py b/src/maxtext/utils/lora_utils.py index ba7d540dae..482a3e6ee6 100644 --- a/src/maxtext/utils/lora_utils.py +++ b/src/maxtext/utils/lora_utils.py @@ -14,7 +14,7 @@ """Common LoRA utils needed to support LoRA adapters.""" - +from collections.abc import Mapping from functools import partial import json import os @@ -38,6 +38,9 @@ from maxtext.utils import sharding from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR +# NNX-only imports (train_state_nnx, model_creation_utils) are loaded lazily +# inside the NNX dispatch branches so the Linen-only flow doesn't pull them in. + def apply_lora_on_base_params(base_params, lora_params, lora_scale_factor=1.0): """ @@ -118,8 +121,21 @@ def unapply_lora_recursively(base_params, lora_params, module_name): def load_adapter(config, base_abstract_state_params, adapter_config_path, adapter_weights_path): - """ - Load the LoRA weights into a PyTree and return it. + """Load a LoRA adapter from disk and return its parameters. + + When `config.pure_nnx` is True, `base_abstract_state_params` is the NNX + abstract param state (no outer `params` wrapper) and the returned + `lora_params` follows the same shape. Otherwise both use the Linen tree. + + Args: + config: Top-level MaxText config. + base_abstract_state_params: Abstract param state of the base model. + adapter_config_path: Path to `adapter_config.json` (local or GCS). + adapter_weights_path: Path to the adapter weights directory. + + Returns: + A tuple `(lora_params, lora_config)`. Both are None when + `adapter_config_path` is empty. """ # Load LoRA weights lora_params = None @@ -137,7 +153,10 @@ def load_adapter(config, base_abstract_state_params, adapter_config_path, adapte if not gcs_utils.gcs_path_exists(f"{adapter_weights_path}/commit_success.txt"): raise FileNotFoundError(f"Failed to read lora_weights from {adapter_weights_path}.") - lora_state, _ = get_lora_abstract_state(base_abstract_state_params, lora_config) + if config.pure_nnx: + lora_state, _ = get_lora_abstract_state_nnx(base_abstract_state_params, lora_config) + else: + lora_state, _ = get_lora_abstract_state(base_abstract_state_params, lora_config) with nn_partitioning.axis_rules(config.logical_axis_rules): lora_params = checkpointing.load_params_from_path( @@ -152,22 +171,27 @@ def load_adapter(config, base_abstract_state_params, adapter_config_path, adapte def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager, lora_adapter_path): - """We initialize the model and optimizer state, and optionally load from a - checkpoint as necessary. + """Initialize the LoRA train state and optionally restore it from a checkpoint. + + On the NNX path, `model` is unused; the abstract state is built from + `model_creation_utils.create_nnx_abstract_model` and `lora_state.params` + follows the NNX shape. On the Linen path the existing `{"params": ...}` + tree shape is preserved. Args: - model: the flax model to initialize - tx: the optax.GradientTransformation - config: config object - rng: jax.prng key - mesh: jax.devices() mesh - checkpoint_manager: an Orbax checkpointing.CheckpointManager object - lora_adapter_path: Path of the LoRA adapter which is expected to have - `adapter_config.json` and adapter weights + model: Linen `nn.Module` used on the Linen path; ignored on NNX. + data_iterator: Data iterator passed through to `load_state_if_possible`. + tx: Optax gradient transformation for the optimizer. + config: Top-level MaxText config. + rng: PRNG key used for the Linen init. + mesh: JAX device mesh. + checkpoint_manager: Orbax `CheckpointManager` for the adapter. + lora_adapter_path: Path to the adapter directory containing + `adapter_config.json`. Returns: - state: the initialized train state - state_mesh_annotations: the mesh annotations for the train state + A tuple `(lora_config, lora_state, lora_state_annotations)`. All three + are None when `lora_adapter_path` is empty. """ lora_state = None @@ -176,21 +200,32 @@ def setup_initial_lora_state(model, data_iterator, tx, config, rng, mesh, checkp if lora_adapter_path: max_logging.log(f"Setting initial state of LoRA with lora_adapter_path = {lora_adapter_path}") - # LoRA adapters on disk are Linen-format and downstream expects Linen TrainState. - # Route to Linen regardless of pure_nnx; native NNX LoRA is a separate effort. if config.pure_nnx: - max_logging.log( - "WARNING: LoRA does not yet support pure_nnx natively; " - "running on the Linen path. NNX-format checkpoints will not load correctly here." - ) - init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) + # pylint: disable=import-outside-toplevel + from maxtext.common import train_state_nnx + from maxtext.utils import model_creation_utils + + _create_model_partial, _ = model_creation_utils.create_nnx_abstract_model(config, mesh) + + def create_train_state_fn(): + nnx_model = _create_model_partial() + optimizer = nnx.Optimizer(nnx_model, tx, wrt=nnx.Param) + return train_state_nnx.TrainStateNNX(nnx_model, optimizer) + + init_state_fn = create_train_state_fn + else: + init_state_fn = partial(maxtext_utils.init_initial_state, model, tx, config, True, rng) unboxed_abstract_state, _, _ = maxtext_utils.get_abstract_state(config, mesh, init_state_fn, True) lora_config_path = lora_adapter_path + "adapter_config.json" lora_config = gcs_utils.read_json_from_gcs(lora_config_path) - lora_state, lora_state_annotations = get_lora_abstract_state(unboxed_abstract_state.params, lora_config) + if config.pure_nnx: + base_abstract_params = _nnx_param_subtree(unboxed_abstract_state) + lora_state, lora_state_annotations = get_lora_abstract_state_nnx(base_abstract_params, lora_config) + else: + lora_state, lora_state_annotations = get_lora_abstract_state(unboxed_abstract_state.params, lora_config) lora_weights_path = f"{lora_adapter_path}/0/items" @@ -610,3 +645,188 @@ def _map_to_state(path, variable): nnx.update(trainer.model, abstract_lora_params) max_logging.log(f"LoRA restore complete from '{lora_restore_path}'.") return trainer + + +# NNX-shaped LoRA helpers. +# +# The Linen walkers above use `isinstance(x, dict)` and unwrapped leaves. NNX +# trees use `nnx.State` (a Mapping that is not a dict) and Variable-wrapped +# leaves, so a separate set of walkers is needed. The math (W += B @ A * s) +# is identical to the Linen path. + + +def _is_nnx_branch(x): + """Return True if `x` should be recursed into as a sub-tree.""" + return isinstance(x, Mapping) + + +def _nnx_param_subtree(unboxed_abstract_state): + """Return the `model` substate, peeling off the outer `TrainStateNNX` wrapping.""" + return unboxed_abstract_state["model"] if "model" in unboxed_abstract_state else unboxed_abstract_state + + +def apply_lora_on_base_params_nnx(base_params, lora_params, lora_scale_factor=1.0): + """Apply LoRA deltas to `base_params` on the NNX path. + + Standard LoRA decomposition: `W_new = W + lora_a @ lora_b * scale`, where + `lora_a` is the down-projection of shape `(emb, rank)` and `lora_b` is + the up-projection of shape `(rank, num_heads, head_dim)`. Mutates + `base_params` in place. + + Mirrors `apply_lora_on_base_params` but operates on an `nnx.State`-shaped + tree (a nested `Mapping` whose leaves are arrays). The recursion follows + the natural nested-dict structure of the lora tree, matching the Linen + sibling above. + """ + + def recurse(base_node, lora_node, path): + # Leaf-level node for a target module: it contains lora_a and lora_b + # side by side, so we handle the pair together and stop descending. + if "lora_a.kernel" in lora_node: + lora_a = lora_node["lora_a.kernel"] + lora_b = lora_node["lora_b.kernel"] + if lora_a is not None and lora_b is not None: + base_node["kernel"] = base_node["kernel"] + jnp.einsum("er,rnd->end", lora_a, lora_b) * lora_scale_factor + return + for name, lora_child in lora_node.items(): + if _is_nnx_branch(lora_child): + recurse(base_node[name], lora_child, f"{path}.{name}") + elif lora_child is not None: + raise ValueError(f"Unexpected non-lora key ({path}.{name}) in lora_params") + + recurse(base_params, lora_params, "") + + +def unapply_lora_from_base_params_nnx(base_params, lora_params, lora_scale_factor=1.0): + """Unapply LoRA deltas from `base_params` on the NNX path. + + Symmetric inverse of `apply_lora_on_base_params_nnx`: `W -= lora_a @ lora_b * scale` + at each target module. Mutates `base_params` in place. + """ + + def recurse(base_node, lora_node, path): + # Leaf-level node for a target module: handle the lora_a / lora_b pair together. + if "lora_a.kernel" in lora_node: + lora_a = lora_node["lora_a.kernel"] + lora_b = lora_node["lora_b.kernel"] + if lora_a is not None and lora_b is not None: + base_node["kernel"] = base_node["kernel"] - jnp.einsum("er,rnd->end", lora_a, lora_b) * lora_scale_factor + return + for name, lora_child in lora_node.items(): + if _is_nnx_branch(lora_child): + recurse(base_node[name], lora_child, f"{path}.{name}") + elif lora_child is not None: + raise ValueError(f"Unexpected non-lora key ({path}.{name}) in lora_params") + + recurse(base_params, lora_params, "") + + +def get_lora_abstract_state_nnx(base_abstract_params, lora_config): + """Build an abstract LoRA state from an NNX-shaped base abstract state. + + Walks `base_abstract_params` (the abstract `state.model` substate) and + emits a parallel tree with `lora_a.kernel` and `lora_b.kernel` leaves at + target attention paths, and `None` elsewhere. Shardings are derived from + the matching base leaves. + + Args: + base_abstract_params: NNX abstract param state whose leaves are + `jax.ShapeDtypeStruct`. + lora_config: Adapter config dict containing `target_modules` and `r`. + + Returns: + A tuple `(lora_state, lora_state_mesh_annotations)` matching the shape + returned by the Linen `get_lora_abstract_state`. + """ + other_lora_format_to_jax_format = { + "q_proj": "self_attention.query", + "k_proj": "self_attention.key", + "v_proj": "self_attention.value", + "o_proj": "self_attention.out", + } + + lora_target_modules = [other_lora_format_to_jax_format.get(s, s) for s in lora_config["target_modules"]] + lora_rank = int(lora_config["r"]) + + def get_lora_param_shape(base_array_shape, lora_module): + if len(base_array_shape) > 4: + raise ValueError(f"Unsupported base array shape {base_array_shape} (>4D)") + if lora_module in ("self_attention.query", "self_attention.key", "self_attention.value"): + lora_a_shape = base_array_shape[:-2] + (lora_rank,) + lora_b_shape = (lora_rank,) + base_array_shape[1:] + elif lora_module == "self_attention.out": + lora_a_shape = base_array_shape[:-1] + (lora_rank,) + if len(base_array_shape) == 4: + lora_b_shape = (lora_rank, base_array_shape[1], base_array_shape[-1]) + else: + lora_b_shape = (lora_rank, base_array_shape[-1]) + else: + raise ValueError(f"Unsupported lora_module={lora_module}") + return lora_a_shape, lora_b_shape + + def get_lora_param_sharding(base_param_sharding, lora_module): + if base_param_sharding is None: + return None, None + base_pspec = base_param_sharding.spec + if len(base_pspec) > 4: + raise ValueError("PartitionSpec size > 4 not supported") + if lora_module in ("self_attention.query", "self_attention.key", "self_attention.value"): + lora_a_pspec = jax.sharding.PartitionSpec(*(base_pspec[:-2] + ((),))) + lora_b_pspec = jax.sharding.PartitionSpec(*(((),) + base_pspec[1:])) + elif lora_module == "self_attention.out": + lora_a_pspec = jax.sharding.PartitionSpec(*(base_pspec[:-1] + ((),))) + if len(base_pspec) == 4: + lora_b_pspec = jax.sharding.PartitionSpec((), base_pspec[1], base_pspec[-1]) + else: + lora_b_pspec = jax.sharding.PartitionSpec((), base_pspec[-1]) + else: + raise ValueError(f"Unsupported lora_module={lora_module}") + mesh = base_param_sharding.mesh + mem_kind = base_param_sharding.memory_kind + return ( + jax.sharding.NamedSharding(mesh=mesh, spec=lora_a_pspec, memory_kind=mem_kind), + jax.sharding.NamedSharding(mesh=mesh, spec=lora_b_pspec, memory_kind=mem_kind), + ) + + def module_is_target(module_path): + for tgt in lora_target_modules: + if tgt in module_path: + return tgt + return None + + def add_lora(out_node, base_node, path): + for name, child in base_node.items(): + if _is_nnx_branch(child): + out_node[name] = {} + add_lora(out_node[name], child, f"{path}.{name}") + else: + if name not in ("kernel", "scale", "embedding"): + raise ValueError(f"Unexpected key={name} in base abstract params at {path}") + if not isinstance(child, jax.ShapeDtypeStruct): + raise ValueError(f"Unexpected leaf type {type(child).__name__} at {path}.{name}") + target_module = module_is_target(path) + if target_module is not None: + a_shape, b_shape = get_lora_param_shape(child.shape, target_module) + a_sharding, b_sharding = get_lora_param_sharding(child.sharding, target_module) + out_node["lora_a.kernel"] = jax.ShapeDtypeStruct(shape=a_shape, dtype=child.dtype, sharding=a_sharding) + out_node["lora_b.kernel"] = jax.ShapeDtypeStruct(shape=b_shape, dtype=child.dtype, sharding=b_sharding) + else: + out_node[name] = None + + lora_abstract_params = {} + add_lora(lora_abstract_params, base_abstract_params, "") + + unboxed_abstract_lora_state = train_state.TrainState( + step=0, apply_fn=None, params=lora_abstract_params, tx=None, opt_state={} # type: ignore + ) + lora_state_mesh_annotations = train_state.TrainState( + step=0, + apply_fn=None, + params=jax.tree_util.tree_map( + lambda x: x.sharding.spec if x.sharding is not None else None, + lora_abstract_params, + ), + tx=None, # type: ignore + opt_state={}, + ) + return unboxed_abstract_lora_state, lora_state_mesh_annotations diff --git a/tests/integration/maxengine_test.py b/tests/integration/maxengine_test.py index efe8c6d55d..c4599f1438 100644 --- a/tests/integration/maxengine_test.py +++ b/tests/integration/maxengine_test.py @@ -262,11 +262,16 @@ def test_quantize_raises_for_nnx(self): with self.assertRaises(NotImplementedError): engine.load_params(rng=self.rng) - def test_lora_raises_for_nnx(self): - """NNX path raises NotImplementedError for LoRA.""" + def test_lora_load_single_adapter_reaches_loader_on_nnx(self): + """pure_nnx + LoRA: load_single_adapter dispatches to the NNX loader. + + A nonexistent adapter path should raise FileNotFoundError from the + loader itself. A NotImplementedError here would mean the dispatch + never reached the loader (i.e. the legacy carve-out is still in place). + """ cfg = self._init_nnx_pyconfig() engine = maxengine.MaxEngine(cfg, jax.devices()) - with self.assertRaises(NotImplementedError): + with self.assertRaises(FileNotFoundError): engine.load_single_adapter("/nonexistent/adapter/path") def test_prefill_multisampling_nnx(self): diff --git a/tests/unit/grpo_nnx_test.py b/tests/unit/grpo_nnx_test.py new file mode 100644 index 0000000000..77f6361b9b --- /dev/null +++ b/tests/unit/grpo_nnx_test.py @@ -0,0 +1,231 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for `grpo_loss_fn_nnx`, `compute_log_probs_nnx`, plus a small +Linen-path regression block (the repo's existing Linen GRPO integration test +is TPU-only).""" + +import types +import unittest + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +from flax import nnx + +from maxtext.experimental.rl import grpo_trainer +from maxtext.experimental.rl import grpo_utils + + +class _MockTransformer(nnx.Module): + """Tiny NNX module that responds to the kwargs `compute_log_probs_nnx` uses.""" + + def __init__(self, vocab_size: int, embed_dim: int, rngs: nnx.Rngs): + self.embed = nnx.Embed(vocab_size, embed_dim, rngs=rngs) + self.proj = nnx.Linear(embed_dim, vocab_size, rngs=rngs) + + def __call__( + self, + decoder_input_tokens, + decoder_positions=None, + decoder_segment_ids=None, + enable_dropout=False, + **kwargs, + ): + del decoder_positions, decoder_segment_ids, enable_dropout, kwargs + return self.proj(self.embed(decoder_input_tokens)) + + +def _make_grpo_config(**overrides): + """Minimal config namespace covering every field `grpo_loss_fn_nnx` reads.""" + base = { + "train_data_columns": "prompt", + "num_generations": 2, + "grpo_epsilon": 0.2, + "grpo_beta": 0.1, + "num_experts": 1, + "decode_sampling_temperature": 1.0, + "enable_dropout": False, + "use_dpo": False, + } + base.update(overrides) + return types.SimpleNamespace(**base) + + +def _make_grpo_batch(B=2, G=2, S=6): + """Minimal GRPO batch: `B` prompts, `G` generations each (total `B*G`), seq length `S`.""" + total = B * G + prompts = jnp.tile(jnp.arange(S, dtype=jnp.int32), (total, 1)) + return { + "prompt_completions": prompts, + "prompt_completions_position": prompts, + "prompt_completions_segmentation": jnp.ones((total, S), dtype=jnp.int32), + "ar_completions_segmentation": jnp.array([[0, 0, 1, 1, 1, 0]] * total, dtype=jnp.int32), + "completions_logprobs": None, # off-policy + } + + +class TestGrpoLossFnNnx(unittest.TestCase): + """Behavior of `grpo_loss_fn_nnx` on a synthetic policy / reference pair.""" + + def setUp(self): + self.policy = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + # Use the same seed so the reference starts identical to the policy. + self.reference = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + self.config = _make_grpo_config() + self.data = _make_grpo_batch() + + def test_aux_structure_matches_linen(self): + """`grpo_loss_fn_nnx` returns the same `LossAux` dataclass shape as `grpo_loss_fn`.""" + loss, aux = grpo_trainer.grpo_loss_fn_nnx( + self.policy, self.config, self.data, None, None, self.reference, is_train=True + ) + self.assertIsInstance(aux, grpo_trainer.LossAux) + for field in ( + "total_loss", + "avg_reward", + "avg_reward_std", + "avg_advantage", + "completion_length", + "moe_lb_loss", + "total_weights", + ): + self.assertTrue(hasattr(aux, field), f"aux missing field {field}") + self.assertTrue(jnp.isfinite(loss)) + + def test_unused_dropout_rng_and_params_args_are_ignored(self): + """`dropout_rng` and `params` are positional placeholders, so their values do not affect the loss.""" + a = grpo_trainer.grpo_loss_fn_nnx(self.policy, self.config, self.data, None, None, self.reference, is_train=True) + b = grpo_trainer.grpo_loss_fn_nnx( + self.policy, self.config, self.data, jax.random.key(99), {"junk": 1}, self.reference, is_train=True + ) + np.testing.assert_allclose(np.asarray(a[0]), np.asarray(b[0]), rtol=1e-6) + + def test_identical_policy_and_reference_zero_kl(self): + """When the policy and reference are identical, the per-token KL is zero.""" + cfg = _make_grpo_config(grpo_beta=0.5) + _, aux = grpo_trainer.grpo_loss_fn_nnx(self.policy, cfg, self.data, None, None, self.reference, is_train=True) + self.assertIsNotNone(aux.avg_kl) + np.testing.assert_allclose(np.asarray(aux.avg_kl), 0.0, atol=1e-5) + + def test_grpo_beta_zero_avg_kl_is_none(self): + cfg = _make_grpo_config(grpo_beta=0.0) + _, aux = grpo_trainer.grpo_loss_fn_nnx(self.policy, cfg, self.data, None, None, self.reference, is_train=True) + self.assertIsNone(aux.avg_kl) + + def test_value_and_grad_flows_only_to_policy(self): + """`nnx.value_and_grad` over the policy yields finite grads; reference is left alone.""" + + def loss_only(policy_model): + loss, _ = grpo_trainer.grpo_loss_fn_nnx( + policy_model, self.config, self.data, None, None, self.reference, is_train=True + ) + return loss + + # nnx.value_and_grad returns (value, grad_state) where grad_state holds nnx.Param leaves. + _, grads = nnx.value_and_grad(loss_only, argnums=0)(self.policy) + leaves = jax.tree_util.tree_leaves(grads) + self.assertGreater(len(leaves), 0) + for leaf in leaves: + self.assertTrue(np.all(np.isfinite(np.asarray(leaf))), "policy grad has non-finite entries") + + +class TestComputeLogProbsNnx(unittest.TestCase): + """Shape contract of `compute_log_probs_nnx`.""" + + def test_returns_correct_shape(self): + config = _make_grpo_config() + data = _make_grpo_batch() + model = _MockTransformer(vocab_size=8, embed_dim=4, rngs=nnx.Rngs(0)) + log_probs, _ = grpo_utils.compute_log_probs_nnx( + model, + data["prompt_completions"], + data["prompt_completions_position"], + data["prompt_completions_segmentation"], + data["ar_completions_segmentation"], + config, + is_train=False, + ) + # Inputs are [B, S] and log_probs are [B, S - 1]. + self.assertEqual(log_probs.shape, (data["prompt_completions"].shape[0], data["prompt_completions"].shape[1] - 1)) + + +# --------------------------------------------------------------------------- +# Linen-path regression smoke tests +# --------------------------------------------------------------------------- + + +class _MockLinenTransformer(nn.Module): + """Tiny Linen module with the `model.apply(...)` signature that Linen `compute_log_probs` expects.""" + + vocab_size: int + embed_dim: int + + @nn.compact + def __call__(self, inputs, positions, decoder_segment_ids=None, enable_dropout=False): + del positions, decoder_segment_ids, enable_dropout + embed = nn.Embed(num_embeddings=self.vocab_size, features=self.embed_dim, name="embed")(inputs) + return nn.Dense(features=self.vocab_size, name="proj")(embed) + + +class TestLinenGrpoRegression(unittest.TestCase): + """Smoke tests that the Linen `grpo_loss_fn` and `compute_log_probs` still run on Linen-shaped inputs.""" + + def setUp(self): + self.config = _make_grpo_config() + self.config.pure_nnx = False # Force the Linen dispatch branch. + self.config.gradient_accumulation_steps = 1 + self.data = _make_grpo_batch() + self.model = _MockLinenTransformer(vocab_size=8, embed_dim=4) + rng = jax.random.key(0) + inputs = self.data["prompt_completions"] + self.params = self.model.init(rng, inputs, inputs, decoder_segment_ids=jnp.ones_like(inputs), enable_dropout=False) + self.reference_params = jax.tree_util.tree_map(jnp.copy, self.params) + + def test_linen_grpo_loss_fn_still_runs(self): + """Linen `grpo_loss_fn` returns a finite loss + a `LossAux`.""" + loss, aux = grpo_trainer.grpo_loss_fn( + self.model, + self.config, + self.data, + jax.random.key(1), + self.params, + self.reference_params["params"], # On Linen, reference_params is the inner subtree. + is_train=True, + ) + self.assertTrue(jnp.isfinite(loss)) + self.assertTrue(hasattr(aux, "total_loss")) + self.assertTrue(hasattr(aux, "moe_lb_loss")) + self.assertTrue(hasattr(aux, "total_weights")) + + def test_linen_compute_log_probs_still_runs(self): + """Linen `compute_log_probs` produces shape `[B, S-1]`.""" + log_probs, _ = grpo_utils.compute_log_probs( + self.model, + self.params, + self.data["prompt_completions"], + self.data["prompt_completions_position"], + self.data["prompt_completions_segmentation"], + self.data["ar_completions_segmentation"], + self.config, + is_train=False, + rngs={"dropout": jax.random.key(2), "params": jax.random.key(3)}, + ) + S = self.data["prompt_completions"].shape[1] + self.assertEqual(log_probs.shape, (self.data["prompt_completions"].shape[0], S - 1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/lora_utils_nnx_test.py b/tests/unit/lora_utils_nnx_test.py new file mode 100644 index 0000000000..322425e674 --- /dev/null +++ b/tests/unit/lora_utils_nnx_test.py @@ -0,0 +1,294 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the NNX-shaped LoRA helpers in `lora_utils`, with a small +Linen regression block at the end.""" + +import unittest + +import jax +import jax.numpy as jnp +import numpy as np + +from maxtext.utils.lora_utils import ( + apply_lora_on_base_params, + apply_lora_on_base_params_nnx, + get_lora_abstract_state_nnx, + unapply_lora_from_base_params, + unapply_lora_from_base_params_nnx, +) + + +# --------------------------------------------------------------------------- +# Fake abstract state builders (mirror the NNX vs. Linen tree shapes) +# --------------------------------------------------------------------------- + + +def _make_nnx_attention_abstract(emb=8, num_heads=2, head_dim=4, dtype=jnp.float32): + """Tiny NNX-shaped abstract state for one attention block.""" + + def _sds(shape): + return jax.ShapeDtypeStruct(shape=shape, dtype=dtype, sharding=None) + + return { + "decoder": { + "layers": { + "self_attention": { + "query": {"kernel": _sds((emb, num_heads, head_dim))}, + "key": {"kernel": _sds((emb, num_heads, head_dim))}, + "value": {"kernel": _sds((emb, num_heads, head_dim))}, + "out": {"kernel": _sds((emb, num_heads, head_dim))}, + }, + "mlp": {"wi": {"kernel": _sds((emb, 4 * emb))}}, + }, + "shared_embedding": {"embedding": _sds((100, emb))}, + }, + } + + +def _make_linen_attention_abstract(emb=8, num_heads=2, head_dim=4, dtype=jnp.float32): + """Linen-shaped equivalent (with the `{"params": ...}` outer wrap).""" + return {"params": _make_nnx_attention_abstract(emb, num_heads, head_dim, dtype)} + + +def _lora_config(rank=4, alpha=8.0, target_modules=("q_proj", "v_proj")): + return { + "r": rank, + "lora_alpha": alpha, + "target_modules": list(target_modules), + } + + +# --------------------------------------------------------------------------- +# get_lora_abstract_state_nnx +# --------------------------------------------------------------------------- + + +class TestGetLoraAbstractStateNnx(unittest.TestCase): + """`get_lora_abstract_state_nnx` shape, sharding, and error-path coverage.""" + + def test_lora_shapes_for_query_and_value(self): + abs_params = _make_nnx_attention_abstract(emb=8, num_heads=2, head_dim=4) + state, _ = get_lora_abstract_state_nnx(abs_params, _lora_config(rank=4)) + attn = state.params["decoder"]["layers"]["self_attention"] + + a = attn["query"]["lora_a.kernel"] + b = attn["query"]["lora_b.kernel"] + self.assertEqual(a.shape, (8, 4)) + self.assertEqual(b.shape, (4, 2, 4)) + self.assertEqual(a.dtype, jnp.float32) + self.assertEqual(b.dtype, jnp.float32) + + a = attn["value"]["lora_a.kernel"] + b = attn["value"]["lora_b.kernel"] + self.assertEqual(a.shape, (8, 4)) + self.assertEqual(b.shape, (4, 2, 4)) + + def test_non_target_modules_emit_none_leaves(self): + abs_params = _make_nnx_attention_abstract() + state, _ = get_lora_abstract_state_nnx(abs_params, _lora_config(target_modules=("q_proj",))) + attn = state.params["decoder"]["layers"]["self_attention"] + self.assertIn("lora_a.kernel", attn["query"]) + self.assertIsNone(attn["key"]["kernel"]) + self.assertIsNone(attn["value"]["kernel"]) + self.assertIsNone(attn["out"]["kernel"]) + self.assertIsNone(state.params["decoder"]["layers"]["mlp"]["wi"]["kernel"]) + self.assertIsNone(state.params["decoder"]["shared_embedding"]["embedding"]) + + def test_o_proj_has_distinct_shape(self): + abs_params = _make_nnx_attention_abstract(emb=8, num_heads=2, head_dim=4) + state, _ = get_lora_abstract_state_nnx(abs_params, _lora_config(rank=3, target_modules=("o_proj",))) + out = state.params["decoder"]["layers"]["self_attention"]["out"] + a = out["lora_a.kernel"] + b = out["lora_b.kernel"] + # For a 3D base (emb, num_heads, head_dim): lora_a.shape ends with rank, + # lora_b shape is (rank, head_dim). + self.assertEqual(a.shape, (8, 2, 3)) + self.assertEqual(b.shape, (3, 4)) + + def test_unsupported_leaf_type_raises(self): + bad = {"decoder": {"layers": {"self_attention": {"query": {"kernel": jnp.zeros((4, 2, 2))}}}}} + with self.assertRaises(ValueError): + get_lora_abstract_state_nnx(bad, _lora_config()) + + def test_unexpected_leaf_name_raises(self): + bad = {"decoder": {"layers": {"self_attention": {"query": {"weight": jax.ShapeDtypeStruct((4, 2), jnp.float32)}}}}} + with self.assertRaises(ValueError): + get_lora_abstract_state_nnx(bad, _lora_config()) + + # Linen-vs-NNX numerical parity is covered by TestApplyLoraNnx.test_numerical_parity_with_linen_apply. + + +# --------------------------------------------------------------------------- +# apply / unapply on NNX-shape pure dicts +# --------------------------------------------------------------------------- + + +def _concrete_base(rng=None, emb=4, num_heads=2, head_dim=3): + """Concrete arrays mirroring the abstract structure used above (NNX-shape).""" + if rng is None: + rng = jax.random.key(0) + k1, k2, k3, k4, k5, k6 = jax.random.split(rng, 6) + shape_attn = (emb, num_heads, head_dim) + return { + "decoder": { + "layers": { + "self_attention": { + "query": {"kernel": jax.random.normal(k1, shape_attn)}, + "key": {"kernel": jax.random.normal(k2, shape_attn)}, + "value": {"kernel": jax.random.normal(k3, shape_attn)}, + "out": {"kernel": jax.random.normal(k4, shape_attn)}, + }, + "mlp": {"wi": {"kernel": jax.random.normal(k5, (emb, 4 * emb))}}, + }, + "shared_embedding": {"embedding": jax.random.normal(k6, (100, emb))}, + }, + } + + +def _build_lora_params(base, lora_config_dict, rng): + """Build a concrete LoRA tree (random arrays) matching `base`.""" + abs_tree = jax.tree_util.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=None), base) + lora_state, _ = get_lora_abstract_state_nnx(abs_tree, lora_config_dict) + + def _to_concrete(leaf, rng_key): + if leaf is None: + return None + return jax.random.normal(rng_key, leaf.shape, leaf.dtype) + + leaves, tree = jax.tree_util.tree_flatten(lora_state.params, is_leaf=lambda x: x is None) + rngs = jax.random.split(rng, max(1, len(leaves))) + out_leaves = [_to_concrete(l, r) for l, r in zip(leaves, rngs)] + return jax.tree_util.tree_unflatten(tree, out_leaves) + + +class TestApplyLoraNnx(unittest.TestCase): + """`apply_lora_on_base_params_nnx` round-trip and Linen-vs-NNX parity.""" + + def test_apply_then_unapply_is_identity(self): + rng = jax.random.key(42) + base_orig = _concrete_base(rng) + base = jax.tree_util.tree_map(jnp.copy, base_orig) + lora = _build_lora_params(base, _lora_config(rank=2, target_modules=("q_proj", "v_proj")), jax.random.key(7)) + apply_lora_on_base_params_nnx(base, lora, lora_scale_factor=0.5) + # The query and value kernels are targets and must have changed. + self.assertFalse( + jnp.allclose( + base["decoder"]["layers"]["self_attention"]["query"]["kernel"], + base_orig["decoder"]["layers"]["self_attention"]["query"]["kernel"], + ) + ) + # The key and out kernels are non-targets and must be untouched. + np.testing.assert_array_equal( + np.asarray(base["decoder"]["layers"]["self_attention"]["key"]["kernel"]), + np.asarray(base_orig["decoder"]["layers"]["self_attention"]["key"]["kernel"]), + ) + np.testing.assert_array_equal( + np.asarray(base["decoder"]["layers"]["self_attention"]["out"]["kernel"]), + np.asarray(base_orig["decoder"]["layers"]["self_attention"]["out"]["kernel"]), + ) + unapply_lora_from_base_params_nnx(base, lora, lora_scale_factor=0.5) + np.testing.assert_allclose( + np.asarray(base["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + np.asarray(base_orig["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + rtol=1e-5, + atol=1e-6, + ) + np.testing.assert_allclose( + np.asarray(base["decoder"]["layers"]["self_attention"]["value"]["kernel"]), + np.asarray(base_orig["decoder"]["layers"]["self_attention"]["value"]["kernel"]), + rtol=1e-5, + atol=1e-6, + ) + + def test_numerical_parity_with_linen_apply(self): + """The NNX and Linen apply paths produce identical results on the same inputs.""" + rng = jax.random.key(123) + base_nnx = _concrete_base(rng) + base_linen = {"params": jax.tree_util.tree_map(jnp.copy, base_nnx)} + lora = _build_lora_params(base_nnx, _lora_config(rank=2, target_modules=("q_proj",)), jax.random.key(5)) + apply_lora_on_base_params_nnx(base_nnx, lora, lora_scale_factor=0.7) + apply_lora_on_base_params(base_linen, {"params": lora}, lora_scale_factor=0.7) + np.testing.assert_allclose( + np.asarray(base_nnx["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + np.asarray(base_linen["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + rtol=1e-6, + ) + + def test_apply_with_unexpected_lora_key_raises(self): + base = _concrete_base() + bad = {"decoder": {"layers": {"self_attention": {"query": {"unexpected": jnp.zeros((4, 2))}}}}} + with self.assertRaises(ValueError): + apply_lora_on_base_params_nnx(base, bad) + + +class TestLinenLoraRegression(unittest.TestCase): + """Smoke tests for the Linen apply / unapply helpers (no other unit test exercises them).""" + + def _linen_pair(self, rng=None): + """Build a Linen-shape (with `{"params": ...}` outer wrapper) base + lora pair.""" + if rng is None: + rng = jax.random.key(99) + base_inner = _concrete_base(rng) + base = {"params": jax.tree_util.tree_map(jnp.copy, base_inner)} + lora_inner = _build_lora_params( + base_inner, + _lora_config(rank=2, target_modules=("q_proj", "v_proj")), + jax.random.key(7), + ) + lora = {"params": lora_inner} + return base, lora + + def test_linen_apply_then_unapply_is_identity(self): + base, lora = self._linen_pair() + base_orig = jax.tree_util.tree_map(jnp.copy, base) + apply_lora_on_base_params(base, lora, lora_scale_factor=0.5) + unapply_lora_from_base_params(base, lora, lora_scale_factor=0.5) + np.testing.assert_allclose( + np.asarray(base["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + np.asarray(base_orig["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"]), + rtol=1e-5, + atol=1e-6, + ) + np.testing.assert_allclose( + np.asarray(base["params"]["decoder"]["layers"]["self_attention"]["value"]["kernel"]), + np.asarray(base_orig["params"]["decoder"]["layers"]["self_attention"]["value"]["kernel"]), + rtol=1e-5, + atol=1e-6, + ) + + def test_linen_apply_only_modifies_target_modules(self): + base, lora = self._linen_pair() + base_orig = jax.tree_util.tree_map(jnp.copy, base) + apply_lora_on_base_params(base, lora, lora_scale_factor=1.0) + # query and value are targets and must change. + self.assertFalse( + jnp.allclose( + base["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"], + base_orig["params"]["decoder"]["layers"]["self_attention"]["query"]["kernel"], + ) + ) + # key and out are non-target and must be untouched. + np.testing.assert_array_equal( + np.asarray(base["params"]["decoder"]["layers"]["self_attention"]["key"]["kernel"]), + np.asarray(base_orig["params"]["decoder"]["layers"]["self_attention"]["key"]["kernel"]), + ) + np.testing.assert_array_equal( + np.asarray(base["params"]["decoder"]["layers"]["self_attention"]["out"]["kernel"]), + np.asarray(base_orig["params"]["decoder"]["layers"]["self_attention"]["out"]["kernel"]), + ) + + +if __name__ == "__main__": + unittest.main()