Skip to content

Commit 626ce66

Browse files
committed
NNX: native LoRA + GRPO (drop maxengine LoRA carve-out, drop GRPO pure_nnx warning)
1 parent 7e9d8d1 commit 626ce66

7 files changed

Lines changed: 1088 additions & 86 deletions

File tree

src/maxtext/experimental/rl/grpo_trainer.py

Lines changed: 272 additions & 52 deletions
Large diffs are not rendered by default.

src/maxtext/experimental/rl/grpo_utils.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121
import jaxtyping
2222
from typing import Any, Callable
2323

24+
from flax import nnx
25+
2426
from maxtext.common.common_types import DecoderBlockType
25-
from maxtext.inference.offline_engine import InputData
2627
from maxtext.utils import max_logging
2728
from maxtext.utils import max_utils
2829

@@ -112,6 +113,48 @@ def compute_log_probs(
112113
return token_log_probs, intermediate_outputs
113114

114115

116+
def compute_log_probs_nnx(
    model,
    inputs,
    inputs_position,
    inputs_segmentation,
    completion_segmentation,
    config,
    is_train=False,
):
  """NNX counterpart of `compute_log_probs`.

  Runs one forward pass of `model` and returns the log-probability assigned to
  each next token, with positions outside the completion zeroed out.

  Args:
    model: Callable invoked with the keywords `decoder_input_tokens`,
      `decoder_positions`, `decoder_segment_ids` and `enable_dropout`. It is
      also handed to `nnx.state(..., nnx.Intermediate)`, so there is no
      separate `params` argument.
    inputs: Token ids, indexed as `inputs[:, 1:]` to form the targets.
    inputs_position: Forwarded to the model as `decoder_positions`.
    inputs_segmentation: Forwarded to the model as `decoder_segment_ids`.
    completion_segmentation: 2-D array used both as the keep-mask and as a
      multiplicative weight on the result (presumably 0/1-valued; confirm
      against the caller).
    config: Must expose `enable_dropout` and `decode_sampling_temperature`.
    is_train: When falsy, dropout is forced off regardless of `config`.

  Returns:
    A `(token_log_probs, intermediate_outputs)` tuple, where
    `intermediate_outputs` is the pure-dict view of the model's
    `nnx.Intermediate` state taken right after the forward pass.
  """
  # Dropout is only ever enabled on the training path.
  use_dropout = config.enable_dropout if is_train else False
  logits = model(
      decoder_input_tokens=inputs,
      decoder_positions=inputs_position,
      decoder_segment_ids=inputs_segmentation,
      enable_dropout=use_dropout,
  )
  # Intermediates live on the module itself, so read them off after the call.
  intermediate_outputs = nnx.state(model, nnx.Intermediate).to_pure_dict()
  scaled_logits = logits / config.decode_sampling_temperature

  # Drop the first column and append a zero column: position t is then gated
  # by the segmentation value of the token it predicts (t + 1).
  rows = completion_segmentation.shape[0]
  cols = completion_segmentation.shape[1]
  next_token_seg = jax.lax.dynamic_slice(completion_segmentation, (0, 1), (rows, cols - 1))
  next_token_seg = jnp.pad(next_token_seg, ((0, 0), (0, 1)), mode="constant", constant_values=0)

  keep = jnp.broadcast_to(next_token_seg[..., None], scaled_logits.shape)
  log_probs = jax.nn.log_softmax(jnp.where(keep, scaled_logits, -jnp.inf), axis=-1)
  # Fully masked rows come out of log_softmax as non-finite; overwrite them.
  log_probs = jnp.where(keep, log_probs, -0.0)[:, :-1, :]

  targets = inputs[:, 1:]
  picked = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
  token_log_probs = picked * next_token_seg[:, :-1]
  return token_log_probs, intermediate_outputs
156+
157+
115158
def generate_offline_completions(config, tokenizer_model, inference_engine, data):
116159
"""Generates completions for a batch of prompts using an offline engine.
117160
@@ -125,6 +168,10 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data
125168
The input `data` dictionary updated with the generated completions,
126169
segmentations, positions, and log-probabilities.
127170
"""
171+
# Lazy import: pulls in maxengine and jetstream stubs, which we only want to
172+
# touch when this function is actually called (i.e. during a real GRPO run).
173+
from maxtext.inference.offline_engine import InputData # pylint: disable=import-outside-toplevel
174+
128175
data[config.train_data_columns] = np.asarray(
129176
jnp.repeat(data[config.train_data_columns], config.num_generations, axis=0)
130177
)
@@ -175,6 +222,30 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data
175222
return data
176223

177224

225+
def pathways_reshard_nnx(
    config, inference_engine, policy_state_model, source_shardings_model, destination_shardings_model
):
  """`pathways_reshard` for the NNX path.

  Extracts the `nnx.Param` state from the policy model, reshards it with
  `reshard_pytree` to the `nnx.Param` shardings taken from
  `destination_shardings_model`, and hands the result to
  `inference_engine.update_params`.

  Args:
    config: Must expose `scan_layers`; only `scan_layers=True` is supported.
    inference_engine: Object whose `update_params` receives the resharded params.
    policy_state_model: NNX model holding the policy parameters.
    source_shardings_model: Unused; accepted so the signature stays parallel
      to the other shardings argument and existing callers keep working.
    destination_shardings_model: NNX model whose `nnx.Param` leaves describe
      the target shardings.

  Raises:
    NotImplementedError: If `config.scan_layers` is falsy (no NNX-aware
      unscan helper yet).
  """
  if not config.scan_layers:
    raise NotImplementedError(
        "GRPO + pure_nnx + scan_layers=False not supported yet. Use scan_layers=True or pure_nnx=False."
    )
  # The source shardings are already encoded on the policy params, so there
  # is no point in splitting the source model just to discard the result.
  del source_shardings_model
  _, policy_params, _ = nnx.split(policy_state_model, nnx.Param, ...)
  _, dest_param_shardings, _ = nnx.split(destination_shardings_model, nnx.Param, ...)
  # Fail loudly if resharding would route data through the host in either direction.
  with (
      jax.transfer_guard_device_to_host("disallow_explicit"),
      jax.transfer_guard_host_to_device("disallow_explicit"),
  ):
    resharded_params = reshard_pytree(policy_params, dest_param_shardings)
  inference_engine.update_params(resharded_params)
247+
248+
178249
def pathways_reshard(config, inference_engine, params, source_shardings, source_mesh, destination_shardings):
179250
"""Reshards model parameters from training to inference sharding.
180251

src/maxtext/inference/maxengine/maxengine.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -438,12 +438,11 @@ def _load_params_nnx(self, params, rng):
438438
return params_state
439439

440440
def load_single_adapter(self, adapter_path):
441+
"""Load a single LoRA adapter from `adapter_path`.
442+
443+
Expects `adapter_config.json` plus adapter weights at `<adapter_path>/0/items`.
444+
The returned `params` shape matches `self.abstract_params` (NNX or Linen).
441445
"""
442-
Load Single adapter from adapter_path.
443-
Expect adapter_config.json and LoRA adapter weights at this path within subdirectory `/0/items`.
444-
"""
445-
if self.config.pure_nnx:
446-
raise NotImplementedError("pure_nnx + LoRA not yet supported. Use pure_nnx=False.")
447446
adapter_config_path = os.path.join(adapter_path, "adapter_config.json")
448447
adapter_weights_path = os.path.join(adapter_path, "0", "items")
449448

@@ -466,14 +465,20 @@ def apply_adapter(self, base_params, adapter_config, adapter_params):
466465

467466
lora_rank = int(adapter_config["r"])
468467
lora_scale_factor = float(adapter_config["lora_alpha"]) / lora_rank
469-
lora_utils.apply_lora_on_base_params(base_params, adapter_params, lora_scale_factor)
468+
if self.config.pure_nnx:
469+
lora_utils.apply_lora_on_base_params_nnx(base_params, adapter_params, lora_scale_factor)
470+
else:
471+
lora_utils.apply_lora_on_base_params(base_params, adapter_params, lora_scale_factor)
470472

471473
def unapply_adapter(self, base_params, adapter_config, adapter_params):
  """Remove a previously merged LoRA adapter from `base_params`.

  The scale factor is `lora_alpha / r`, both read from `adapter_config`. The
  NNX or Linen helper in `lora_utils` is picked by `self.config.pure_nnx`.
  Nothing is returned here; the helper is presumably expected to update
  `base_params` in place (confirm in `lora_utils`).
  """
  rank = int(adapter_config["r"])
  scale = float(adapter_config["lora_alpha"]) / rank
  # Same arguments either way; only the helper differs between the two paths.
  unapply_fn = (
      lora_utils.unapply_lora_from_base_params_nnx
      if self.config.pure_nnx
      else lora_utils.unapply_lora_from_base_params
  )
  unapply_fn(base_params, adapter_params, scale)
477482

478483
def quantize_params(self, state, rng: PRNGKeyType | None = None):
479484
"""Forward pass to quantize decode params."""

0 commit comments

Comments
 (0)