Skip to content

Commit da02ec8

Browse files
committed
NNX: native LoRA + GRPO (drop maxengine LoRA carve-out, drop GRPO pure_nnx warning)
1 parent 7c68a9d commit da02ec8

7 files changed

Lines changed: 1213 additions & 82 deletions

File tree

src/maxtext/experimental/rl/grpo_trainer.py

Lines changed: 326 additions & 47 deletions
Large diffs are not rendered by default.

src/maxtext/experimental/rl/grpo_utils.py

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121
import jaxtyping
2222
from typing import Any, Callable
2323

24+
from flax import nnx
25+
2426
from maxtext.common.common_types import DecoderBlockType
25-
from maxtext.inference.offline_engine import InputData
2627
from maxtext.utils import max_logging
2728
from maxtext.utils import max_utils
2829

@@ -112,6 +113,63 @@ def compute_log_probs(
112113
return token_log_probs, intermediate_outputs
113114

114115

116+
def compute_log_probs_nnx(
    model,
    inputs,
    inputs_position,
    inputs_segmentation,
    completion_segmentation,
    config,
    is_train=False,
):
  """Return per-token log-probabilities from a forward pass of an NNX policy.

  NNX counterpart of `compute_log_probs`. The `nnx.Module` is called directly,
  so no `params` or `rngs` argument is passed. After the call, the model's
  `nnx.Intermediate` state is collected with `nnx.state(model, nnx.Intermediate)`
  and returned as a pure dict.

  Args:
    model: Policy `nnx.Module`.
    inputs: A `[B, L]` array of input token ids.
    inputs_position: A `[B, L]` array of token positions.
    inputs_segmentation: A `[B, L]` array of segment ids.
    completion_segmentation: A `[B, L]` array that masks the completion
      portion of the sequence.
    config: Training configuration object. Reads `enable_dropout` (only when
      `is_train` is truthy) and `decode_sampling_temperature`.
    is_train: Whether to run the forward in training mode. When falsy,
      dropout is always disabled.

  Returns:
    A tuple `(token_log_probs, intermediate_outputs)` where
    `token_log_probs` has shape `[B, L-1]`.
  """
  # Dropout follows the config only in training mode.
  if is_train:
    use_dropout = config.enable_dropout
  else:
    use_dropout = False

  raw_logits = model(
      decoder_input_tokens=inputs,
      decoder_positions=inputs_position,
      decoder_segment_ids=inputs_segmentation,
      enable_dropout=use_dropout,
  )
  intermediate_outputs = nnx.state(model, nnx.Intermediate).to_pure_dict()
  scaled_logits = raw_logits / config.decode_sampling_temperature

  # Targets are the inputs with the first token dropped.
  next_token_ids = inputs[:, 1:]

  # Drop the first column of the completion mask, then zero-pad the end so the
  # mask keeps the same sequence length as the logits.
  num_rows = completion_segmentation.shape[0]
  num_cols = completion_segmentation.shape[1]
  completion_mask = jax.lax.dynamic_slice(completion_segmentation, (0, 1), (num_rows, num_cols - 1))
  completion_mask = jnp.pad(completion_mask, ((0, 0), (0, 1)), mode="constant", constant_values=0)

  # Broadcast the mask over the last (logit) axis.
  logit_mask = jnp.broadcast_to(completion_mask[..., None], scaled_logits.shape)

  # Masked-out positions get -inf logits before the softmax; their log-probs
  # are then overwritten with -0.0.
  all_log_probs = jax.nn.log_softmax(jnp.where(logit_mask, scaled_logits, -jnp.inf), axis=-1)
  all_log_probs = jnp.where(logit_mask, all_log_probs, -0.0)[:, :-1, :]

  # Pick the log-prob of each target token, then zero out non-completion slots.
  picked = jnp.take_along_axis(all_log_probs, next_token_ids[..., None], axis=-1)[..., 0]
  token_log_probs = picked * completion_mask[:, :-1]
  return token_log_probs, intermediate_outputs
171+
172+
115173
def generate_offline_completions(config, tokenizer_model, inference_engine, data):
116174
"""Generates completions for a batch of prompts using an offline engine.
117175
@@ -125,6 +183,10 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data
125183
The input `data` dictionary updated with the generated completions,
126184
segmentations, positions, and log-probabilities.
127185
"""
186+
# Lazy import: pulls in maxengine and jetstream stubs, which we only want to
187+
# touch when this function is actually called (i.e. during a real GRPO run).
188+
from maxtext.inference.offline_engine import InputData # pylint: disable=import-outside-toplevel
189+
128190
data[config.train_data_columns] = np.asarray(
129191
jnp.repeat(data[config.train_data_columns], config.num_generations, axis=0)
130192
)
@@ -175,6 +237,40 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data
175237
return data
176238

177239

240+
def pathways_reshard_nnx(
    config, inference_engine, policy_state_model, source_shardings_model, destination_shardings_model
):
  """Reshard NNX policy params and push them into the inference engine.

  Extracts the `nnx.Param` state from `policy_state_model` and from
  `destination_shardings_model`, calls `reshard_pytree` on the pair under JAX
  transfer guards, and hands the result to `inference_engine.update_params`.
  Requires `scan_layers=True`; the Linen `unscan_train_state_params` helper
  has no NNX equivalent yet.

  Args:
    config: Training configuration object. Only `scan_layers` is read.
    inference_engine: Inference engine to receive the resharded params.
    policy_state_model: Training-side `state.model` substate.
    source_shardings_model: Shardings for `policy_state_model`. Accepted for
      signature parity but unused, because the same shardings are already
      attached to the params.
    destination_shardings_model: Shardings for the inference-side model.

  Raises:
    NotImplementedError: If `config.scan_layers` is falsy.
  """
  if not config.scan_layers:
    raise NotImplementedError(
        "GRPO + pure_nnx + scan_layers=False not supported yet. Use scan_layers=True or pure_nnx=False."
    )

  # The source shardings are never consulted, so discard the argument directly
  # instead of filtering its Param state only to throw the result away.
  del source_shardings_model

  policy_params = nnx.state(policy_state_model, nnx.Param)
  dest_param_shardings = nnx.state(destination_shardings_model, nnx.Param)

  # Explicit device<->host transfers are disallowed while resharding.
  with (
      jax.transfer_guard_device_to_host("disallow_explicit"),
      jax.transfer_guard_host_to_device("disallow_explicit"),
  ):
    resharded_params = reshard_pytree(policy_params, dest_param_shardings)
  inference_engine.update_params(resharded_params)
272+
273+
178274
def pathways_reshard(config, inference_engine, params, source_shardings, source_mesh, destination_shardings):
179275
"""Reshards model parameters from training to inference sharding.
180276

src/maxtext/inference/maxengine/maxengine.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -447,12 +447,12 @@ def _load_params_nnx(self, params, rng):
447447
return params_state
448448

449449
def load_single_adapter(self, adapter_path):
450+
"""Load a single LoRA adapter from `adapter_path`.
451+
452+
Expects `adapter_config.json` at the root and adapter weights under
453+
`<adapter_path>/0/items`. The returned `params` follows the same tree
454+
shape as `self.abstract_params` (NNX or Linen, depending on the engine).
450455
"""
451-
Load Single adapter from adapter_path.
452-
Expect adapter_config.json and LoRA adapter weights at this path within subdirectory `/0/items`.
453-
"""
454-
if self.config.pure_nnx:
455-
raise NotImplementedError("pure_nnx + LoRA not yet supported. Use pure_nnx=False.")
456456
adapter_config_path = os.path.join(adapter_path, "adapter_config.json")
457457
adapter_weights_path = os.path.join(adapter_path, "0", "items")
458458

@@ -475,14 +475,20 @@ def apply_adapter(self, base_params, adapter_config, adapter_params):
475475

476476
lora_rank = int(adapter_config["r"])
477477
lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank
478-
lora_utils.apply_lora_on_base_params(base_params, adapter_params, lora_scale_factor)
478+
if self.config.pure_nnx:
479+
lora_utils.apply_lora_on_base_params_nnx(base_params, adapter_params, lora_scale_factor)
480+
else:
481+
lora_utils.apply_lora_on_base_params(base_params, adapter_params, lora_scale_factor)
479482

480483
def unapply_adapter(self, base_params, adapter_config, adapter_params):
481484
"""Unapply the adapter params from the merged params to get back the base params."""
482485

483486
lora_rank = int(adapter_config["r"])
484487
lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank
485-
lora_utils.unapply_lora_from_base_params(base_params, adapter_params, lora_scale_factor)
488+
if self.config.pure_nnx:
489+
lora_utils.unapply_lora_from_base_params_nnx(base_params, adapter_params, lora_scale_factor)
490+
else:
491+
lora_utils.unapply_lora_from_base_params(base_params, adapter_params, lora_scale_factor)
486492

487493
def quantize_params(self, state, rng: PRNGKeyType | None = None):
488494
"""Forward pass to quantize decode params."""

0 commit comments

Comments
 (0)