Skip to content

Commit 10bfe3f

Browse files
committed
NNX: native LoRA + GRPO (drop maxengine LoRA carve-out, drop GRPO pure_nnx warning)
1 parent 4dc3ae2 commit 10bfe3f

7 files changed

Lines changed: 1071 additions & 84 deletions

File tree

src/maxtext/experimental/rl/grpo_trainer.py

Lines changed: 269 additions & 51 deletions
Large diffs are not rendered by default.

src/maxtext/experimental/rl/grpo_utils.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import jaxtyping
2222
from typing import Any, Callable
2323

24+
from flax import nnx
25+
2426
from maxtext.common.common_types import DecoderBlockType
2527
from maxtext.inference.offline_engine import InputData
2628
from maxtext.utils import max_logging
@@ -112,6 +114,48 @@ def compute_log_probs(
112114
return token_log_probs, intermediate_outputs
113115

114116

117+
def compute_log_probs_nnx(
    model,
    inputs,
    inputs_position,
    inputs_segmentation,
    completion_segmentation,
    config,
    is_train=False,
):
  """`compute_log_probs` for the NNX path.

  `model` is an `nnx.Module` (carries its own params + RNG state), so there's
  no `params` arg. Intermediates are pulled off the model after the forward
  via `nnx.state(model, nnx.Intermediate).to_pure_dict()`.

  Args:
    model: Callable model; invoked with `decoder_input_tokens`,
      `decoder_positions`, `decoder_segment_ids` and `enable_dropout`.
    inputs: Token ids, indexed as `[batch, sequence]`.
    inputs_position: Forwarded to the model as `decoder_positions`.
    inputs_segmentation: Forwarded to the model as `decoder_segment_ids`.
    completion_segmentation: Same leading two dims as `inputs`; entries equal
      to zero mark positions that must not contribute a log-prob.
    config: Reads `enable_dropout` and `decode_sampling_temperature`.
    is_train: Dropout follows `config.enable_dropout` only when this is truthy;
      otherwise dropout is disabled.

  Returns:
    A `(token_log_probs, intermediate_outputs)` tuple. `token_log_probs` has one
    fewer position than `inputs` along axis 1 (entry `t` scores token `t + 1`)
    and is scaled by the shifted `completion_segmentation`, so masked positions
    are zero. `intermediate_outputs` is the pure-dict view of the model's
    `nnx.Intermediate` state after the forward pass.
  """
  logits = model(
      decoder_input_tokens=inputs,
      decoder_positions=inputs_position,
      decoder_segment_ids=inputs_segmentation,
      enable_dropout=(config.enable_dropout if is_train else False),
  )
  intermediate_outputs = nnx.state(model, nnx.Intermediate).to_pure_dict()
  logits = logits / config.decode_sampling_temperature

  # Logit row t predicts token t + 1: drop the last row and shift the targets
  # and the completion mask left by one so all three line up.
  targets = inputs[:, 1:]
  target_segmentation = completion_segmentation[:, 1:]

  log_probs = jax.nn.log_softmax(logits[:, :-1, :], axis=-1)
  token_log_probs = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]

  # Mask after the gather rather than filling whole vocab rows with -inf before
  # the softmax: an all -inf row evaluates to NaN inside log_softmax, which
  # trips `jax_debug_nans` even though the value is discarded afterwards.
  token_log_probs = jnp.where(target_segmentation != 0, token_log_probs, -0.0)
  token_log_probs = token_log_probs * target_segmentation
  return token_log_probs, intermediate_outputs
157+
158+
115159
def generate_offline_completions(config, tokenizer_model, inference_engine, data):
116160
"""Generates completions for a batch of prompts using an offline engine.
117161
@@ -175,6 +219,30 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data
175219
return data
176220

177221

222+
def pathways_reshard_nnx(
    config, inference_engine, policy_state_model, source_shardings_model, destination_shardings_model
):
  """`pathways_reshard` for the NNX path.

  Reshard the policy params onto the inference mesh and push them into the
  inference engine. Requires `scan_layers=True` (no NNX-aware unscan helper yet).

  Args:
    config: Reads `scan_layers`.
    inference_engine: Receives the resharded params via `update_params`.
    policy_state_model: NNX object whose `nnx.Param` state is resharded.
    source_shardings_model: Unused; accepted so the call signature stays
      parallel with the other sharding arguments.
    destination_shardings_model: NNX object whose `nnx.Param` state supplies
      the target shardings handed to `reshard_pytree`.

  Raises:
    NotImplementedError: If `config.scan_layers` is falsy.
  """
  if not config.scan_layers:
    raise NotImplementedError(
        "GRPO + pure_nnx + scan_layers=False not supported yet. Use scan_layers=True or pure_nnx=False."
    )
  # The source layout is already encoded on the policy params, so there is no
  # need to split the source-shardings model just to throw the result away.
  del source_shardings_model
  _, policy_params, _ = nnx.split(policy_state_model, nnx.Param, ...)
  _, dest_param_shardings, _ = nnx.split(destination_shardings_model, nnx.Param, ...)
  # Forbid host<->device transfers while resharding.
  with (
      jax.transfer_guard_device_to_host("disallow_explicit"),
      jax.transfer_guard_host_to_device("disallow_explicit"),
  ):
    resharded_params = reshard_pytree(policy_params, dest_param_shardings)
  inference_engine.update_params(resharded_params)
244+
245+
178246
def pathways_reshard(config, inference_engine, params, source_shardings, source_mesh, destination_shardings):
179247
"""Reshards model parameters from training to inference sharding.
180248

src/maxtext/inference/maxengine/maxengine.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -438,12 +438,11 @@ def _load_params_nnx(self, params, rng):
438438
return params_state
439439

440440
def load_single_adapter(self, adapter_path):
441+
"""Load a single LoRA adapter from `adapter_path`.
442+
443+
Expects `adapter_config.json` plus adapter weights at `<adapter_path>/0/items`.
444+
The returned `params` shape matches `self.abstract_params` (NNX or Linen).
441445
"""
442-
Load Single adapter from adapter_path.
443-
Expect adapter_config.json and LoRA adapter weights at this path within subdirectory `/0/items`.
444-
"""
445-
if self.config.pure_nnx:
446-
raise NotImplementedError("pure_nnx + LoRA not yet supported. Use pure_nnx=False.")
447446
adapter_config_path = os.path.join(adapter_path, "adapter_config.json")
448447
adapter_weights_path = os.path.join(adapter_path, "0", "items")
449448

@@ -466,14 +465,20 @@ def apply_adapter(self, base_params, adapter_config, adapter_params):
466465

467466
lora_rank = int(adapter_config["r"])
468467
lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank
469-
lora_utils.apply_lora_on_base_params(base_params, adapter_params, lora_scale_factor)
468+
if self.config.pure_nnx:
469+
lora_utils.apply_lora_on_base_params_nnx(base_params, adapter_params, lora_scale_factor)
470+
else:
471+
lora_utils.apply_lora_on_base_params(base_params, adapter_params, lora_scale_factor)
470472

471473
def unapply_adapter(self, base_params, adapter_config, adapter_params):
472474
"""Unapply the adapter params from the merged params to get back the base params."""
473475

474476
lora_rank = int(adapter_config["r"])
475477
lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank
476-
lora_utils.unapply_lora_from_base_params(base_params, adapter_params, lora_scale_factor)
478+
if self.config.pure_nnx:
479+
lora_utils.unapply_lora_from_base_params_nnx(base_params, adapter_params, lora_scale_factor)
480+
else:
481+
lora_utils.unapply_lora_from_base_params(base_params, adapter_params, lora_scale_factor)
477482

478483
def quantize_params(self, state, rng: PRNGKeyType | None = None):
479484
"""Forward pass to quantize decode params."""

0 commit comments

Comments
 (0)