2121import jaxtyping
2222from typing import Any , Callable
2323
24+ from flax import nnx
25+
2426from maxtext .common .common_types import DecoderBlockType
25- from maxtext .inference .offline_engine import InputData
2627from maxtext .utils import max_logging
2728from maxtext .utils import max_utils
2829
@@ -112,6 +113,63 @@ def compute_log_probs(
112113 return token_log_probs , intermediate_outputs
113114
114115
def compute_log_probs_nnx(
    model,
    inputs,
    inputs_position,
    inputs_segmentation,
    completion_segmentation,
    config,
    is_train=False,
):
  """Return per-token log-probabilities from an NNX policy forward pass.

  NNX counterpart of `compute_log_probs`. The `nnx.Module` is called directly,
  so no `params` or `rngs` are passed in; values sown during the forward pass
  are collected afterwards with `nnx.state(model, nnx.Intermediate)`.

  Args:
    model: Policy `nnx.Module`.
    inputs: A `[B, L]` array of input token ids.
    inputs_position: A `[B, L]` array of token positions.
    inputs_segmentation: A `[B, L]` array of segment ids.
    completion_segmentation: A `[B, L]` array that masks the completion
      portion of the sequence.
    config: Training configuration object. Reads `enable_dropout` (only when
      `is_train` is true) and `decode_sampling_temperature`.
    is_train: Whether to run the forward in training mode.

  Returns:
    A tuple `(token_log_probs, intermediate_outputs)` where
    `token_log_probs` has shape `[B, L-1]` and `intermediate_outputs` is the
    pure-dict form of the model's `nnx.Intermediate` state.
  """
  # Dropout can only be active in training mode, and then only if configured.
  use_dropout = config.enable_dropout if is_train else False
  logits = model(
      decoder_input_tokens=inputs,
      decoder_positions=inputs_position,
      decoder_segment_ids=inputs_segmentation,
      enable_dropout=use_dropout,
  )
  # Read the sown intermediates back right after the forward pass.
  intermediate_outputs = nnx.state(model, nnx.Intermediate).to_pure_dict()
  scaled_logits = logits / config.decode_sampling_temperature

  # Position t scores token t + 1, so shift the completion mask left by one
  # and zero-pad the last column to keep it aligned with the `[B, L, V]` logits.
  next_token_mask = completion_segmentation[:, 1:]
  padded_mask = jnp.pad(next_token_mask, ((0, 0), (0, 1)), mode="constant", constant_values=0)
  vocab_mask = jnp.broadcast_to(padded_mask[..., None], scaled_logits.shape)

  # Rows outside the completion are fully masked to -inf before the softmax;
  # the second `where` then overwrites those rows with -0.0.
  log_probs = jax.nn.log_softmax(jnp.where(vocab_mask, scaled_logits, -jnp.inf), axis=-1)
  log_probs = jnp.where(vocab_mask, log_probs, -0.0)[:, :-1, :]

  # Pick out the log-probability assigned to each actual next token.
  next_tokens = inputs[:, 1:]
  picked = jnp.take_along_axis(log_probs, next_tokens[..., None], axis=-1)[..., 0]
  token_log_probs = picked * padded_mask[:, :-1]
  return token_log_probs, intermediate_outputs
171+
172+
115173def generate_offline_completions (config , tokenizer_model , inference_engine , data ):
116174 """Generates completions for a batch of prompts using an offline engine.
117175
@@ -125,6 +183,10 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data
125183 The input `data` dictionary updated with the generated completions,
126184 segmentations, positions, and log-probabilities.
127185 """
186+ # Lazy import: pulls in maxengine and jetstream stubs, which we only want to
187+ # touch when this function is actually called (i.e. during a real GRPO run).
188+ from maxtext .inference .offline_engine import InputData # pylint: disable=import-outside-toplevel
189+
128190 data [config .train_data_columns ] = np .asarray (
129191 jnp .repeat (data [config .train_data_columns ], config .num_generations , axis = 0 )
130192 )
@@ -175,6 +237,40 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data
175237 return data
176238
177239
def pathways_reshard_nnx(
    config, inference_engine, policy_state_model, source_shardings_model, destination_shardings_model
):
  """Reshard NNX policy params onto the inference mesh.

  Extracts the `nnx.Param` state from the training-side model substate,
  reshards it to the destination shardings, and pushes the result into the
  inference engine. Requires `scan_layers=True`; the Linen
  `unscan_train_state_params` helper has no NNX equivalent yet.

  Args:
    config: Training configuration object; only `scan_layers` is read.
    inference_engine: Inference engine to receive the resharded params via
      `update_params`.
    policy_state_model: Training-side `state.model` substate.
    source_shardings_model: Shardings for `policy_state_model`. Accepted for
      interface compatibility but never read, because the same shardings are
      already attached to the params.
    destination_shardings_model: Shardings for the inference-side model.

  Raises:
    NotImplementedError: If `config.scan_layers` is false.
  """
  if not config.scan_layers:
    raise NotImplementedError(
        "GRPO + pure_nnx + scan_layers=False not supported yet. Use scan_layers=True or pure_nnx=False."
    )
  # The source shardings are already encoded on the policy params, so the
  # argument is dropped without extracting its state.
  del source_shardings_model
  policy_params = nnx.state(policy_state_model, nnx.Param)
  dest_param_shardings = nnx.state(destination_shardings_model, nnx.Param)
  # Both transfer guards are set to "disallow_explicit" for the duration of
  # the reshard, covering device-to-host and host-to-device transfers.
  with (
      jax.transfer_guard_device_to_host("disallow_explicit"),
      jax.transfer_guard_host_to_device("disallow_explicit"),
  ):
    resharded_params = reshard_pytree(policy_params, dest_param_shardings)
  inference_engine.update_params(resharded_params)
272+
273+
178274def pathways_reshard (config , inference_engine , params , source_shardings , source_mesh , destination_shardings ):
179275 """Reshards model parameters from training to inference sharding.
180276
0 commit comments