Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
373 changes: 326 additions & 47 deletions src/maxtext/experimental/rl/grpo_trainer.py

Large diffs are not rendered by default.

98 changes: 97 additions & 1 deletion src/maxtext/experimental/rl/grpo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
import jaxtyping
from typing import Any, Callable

from flax import nnx

from maxtext.common.common_types import DecoderBlockType
from maxtext.inference.offline_engine import InputData
from maxtext.utils import max_logging
from maxtext.utils import max_utils

Expand Down Expand Up @@ -112,6 +113,63 @@ def compute_log_probs(
return token_log_probs, intermediate_outputs


def compute_log_probs_nnx(
    model,
    inputs,
    inputs_position,
    inputs_segmentation,
    completion_segmentation,
    config,
    is_train=False,
):
  """Return per-token log-probabilities of completion tokens for an NNX policy.

  NNX counterpart of `compute_log_probs`. The `nnx.Module` is called directly
  (no `params` / `rngs` arguments), and whatever the forward pass stored as
  `nnx.Intermediate` state is read back from the module afterwards.

  Args:
    model: Policy `nnx.Module`; called with the decoder keyword arguments.
    inputs: `[B, L]` array of input token ids.
    inputs_position: `[B, L]` array of token positions.
    inputs_segmentation: `[B, L]` array of segment ids.
    completion_segmentation: `[B, L]` array that is non-zero on the
      completion part of each sequence.
    config: Training configuration; `enable_dropout` and
      `decode_sampling_temperature` are read from it.
    is_train: When true, dropout follows `config.enable_dropout`; otherwise
      dropout is disabled.

  Returns:
    A tuple `(token_log_probs, intermediate_outputs)`. `token_log_probs` has
    shape `[B, L-1]` and is zero wherever the shifted completion mask is
    zero. `intermediate_outputs` is the pure-dict form of the module's
    `nnx.Intermediate` state.
  """
  # Dropout is only ever considered in training mode.
  if is_train:
    use_dropout = config.enable_dropout
  else:
    use_dropout = False

  raw_logits = model(
      decoder_input_tokens=inputs,
      decoder_positions=inputs_position,
      decoder_segment_ids=inputs_segmentation,
      enable_dropout=use_dropout,
  )
  intermediate_outputs = nnx.state(model, nnx.Intermediate).to_pure_dict()
  scaled_logits = raw_logits / config.decode_sampling_temperature

  # Shift the completion mask left by one so position t describes the token
  # predicted at t (i.e. token t+1), then zero-pad back to the original length.
  num_rows = completion_segmentation.shape[0]
  num_cols = completion_segmentation.shape[1]
  next_token_mask = jax.lax.dynamic_slice(completion_segmentation, (0, 1), (num_rows, num_cols - 1))
  next_token_mask = jnp.pad(next_token_mask, ((0, 0), (0, 1)), mode="constant", constant_values=0)

  # Expand the mask over the vocabulary axis so it lines up with the logits.
  vocab_mask = jnp.broadcast_to(next_token_mask[..., None], scaled_logits.shape)

  # Masked-out positions get -inf logits before the softmax and are then
  # overwritten with -0.0 so they carry no value into the gather below.
  all_log_probs = jax.nn.log_softmax(jnp.where(vocab_mask, scaled_logits, -jnp.inf), axis=-1)
  all_log_probs = jnp.where(vocab_mask, all_log_probs, -0.0)[:, :-1, :]

  # Pick the log-prob assigned to each actual next token.
  targets = inputs[:, 1:]
  gathered = jnp.take_along_axis(all_log_probs, targets[..., None], axis=-1)[..., 0]
  token_log_probs = gathered * next_token_mask[:, :-1]
  return token_log_probs, intermediate_outputs


def generate_offline_completions(config, tokenizer_model, inference_engine, data):
"""Generates completions for a batch of prompts using an offline engine.

Expand All @@ -125,6 +183,10 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data
The input `data` dictionary updated with the generated completions,
segmentations, positions, and log-probabilities.
"""
# Lazy import: pulls in maxengine and jetstream stubs, which we only want to
# touch when this function is actually called (i.e. during a real GRPO run).
from maxtext.inference.offline_engine import InputData # pylint: disable=import-outside-toplevel

data[config.train_data_columns] = np.asarray(
jnp.repeat(data[config.train_data_columns], config.num_generations, axis=0)
)
Expand Down Expand Up @@ -175,6 +237,40 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data
return data


def pathways_reshard_nnx(
    config, inference_engine, policy_state_model, source_shardings_model, destination_shardings_model
):
  """Move the NNX policy parameters onto the inference sharding layout.

  Extracts the `nnx.Param` state from the training-side model substate,
  reshards it with `reshard_pytree` using the destination shardings, and
  hands the result to `inference_engine.update_params`. Only the scanned
  layer layout is handled; the Linen `unscan_train_state_params` helper has
  no NNX equivalent yet.

  Args:
    config: Training configuration object; `scan_layers` must be truthy.
    inference_engine: Engine whose `update_params` receives the resharded
      parameters.
    policy_state_model: Training-side `state.model` substate.
    source_shardings_model: Shardings for `policy_state_model`. Its
      `nnx.Param` state is extracted but otherwise unused, because the same
      shardings are already attached to the params.
    destination_shardings_model: Shardings for the inference-side model.

  Raises:
    NotImplementedError: If `config.scan_layers` is falsy.
  """
  if not config.scan_layers:
    unsupported_message = (
        "GRPO + pure_nnx + scan_layers=False not supported yet. " "Use scan_layers=True or pure_nnx=False."
    )
    raise NotImplementedError(unsupported_message)

  train_params = nnx.state(policy_state_model, nnx.Param)
  unused_source_shardings = nnx.state(source_shardings_model, nnx.Param)
  inference_shardings = nnx.state(destination_shardings_model, nnx.Param)
  # The source layout is already encoded on `train_params`, so drop it.
  del unused_source_shardings

  # Both transfer guards are set to "disallow_explicit" while resharding.
  with jax.transfer_guard_device_to_host("disallow_explicit"):
    with jax.transfer_guard_host_to_device("disallow_explicit"):
      inference_params = reshard_pytree(train_params, inference_shardings)

  inference_engine.update_params(inference_params)


def pathways_reshard(config, inference_engine, params, source_shardings, source_mesh, destination_shardings):
"""Reshards model parameters from training to inference sharding.

Expand Down
20 changes: 13 additions & 7 deletions src/maxtext/inference/maxengine/maxengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,12 +447,12 @@ def _load_params_nnx(self, params, rng):
return params_state

def load_single_adapter(self, adapter_path):
"""Load a single LoRA adapter from `adapter_path`.

Expects `adapter_config.json` at the root and adapter weights under
`<adapter_path>/0/items`. The returned `params` follows the same tree
shape as `self.abstract_params` (NNX or Linen, depending on the engine).
"""
Load Single adapter from adapter_path.
Expect adapter_config.json and LoRA adapter weights at this path within subdirectory `/0/items`.
"""
if self.config.pure_nnx:
raise NotImplementedError("pure_nnx + LoRA not yet supported. Use pure_nnx=False.")
adapter_config_path = os.path.join(adapter_path, "adapter_config.json")
adapter_weights_path = os.path.join(adapter_path, "0", "items")

Expand All @@ -475,14 +475,20 @@ def apply_adapter(self, base_params, adapter_config, adapter_params):

lora_rank = int(adapter_config["r"])
lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank
lora_utils.apply_lora_on_base_params(base_params, adapter_params, lora_scale_factor)
if self.config.pure_nnx:
lora_utils.apply_lora_on_base_params_nnx(base_params, adapter_params, lora_scale_factor)
else:
lora_utils.apply_lora_on_base_params(base_params, adapter_params, lora_scale_factor)

def unapply_adapter(self, base_params, adapter_config, adapter_params):
"""Unapply the adapter params from the merged params to get back the base params."""

lora_rank = int(adapter_config["r"])
lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank
lora_utils.unapply_lora_from_base_params(base_params, adapter_params, lora_scale_factor)
if self.config.pure_nnx:
lora_utils.unapply_lora_from_base_params_nnx(base_params, adapter_params, lora_scale_factor)
else:
lora_utils.unapply_lora_from_base_params(base_params, adapter_params, lora_scale_factor)

def quantize_params(self, state, rng: PRNGKeyType | None = None):
"""Forward pass to quantize decode params."""
Expand Down
Loading
Loading