2121import jaxtyping
2222from typing import Any , Callable
2323
24+ from flax import nnx
25+
2426from maxtext .common .common_types import DecoderBlockType
25- from maxtext .inference .offline_engine import InputData
2627from maxtext .utils import max_logging
2728from maxtext .utils import max_utils
2829
@@ -112,6 +113,48 @@ def compute_log_probs(
112113 return token_log_probs , intermediate_outputs
113114
114115
116+ def compute_log_probs_nnx (
117+ model ,
118+ inputs ,
119+ inputs_position ,
120+ inputs_segmentation ,
121+ completion_segmentation ,
122+ config ,
123+ is_train = False ,
124+ ):
125+ """`compute_log_probs` for the NNX path.
126+
127+ `model` is an `nnx.Module` (carries its own params + RNG state), so there's
128+ no `params` arg. Intermediates are pulled off the model after the forward
129+ via `nnx.state(model, nnx.Intermediate).to_pure_dict()`.
130+ """
131+ logits = model (
132+ decoder_input_tokens = inputs ,
133+ decoder_positions = inputs_position ,
134+ decoder_segment_ids = inputs_segmentation ,
135+ enable_dropout = (config .enable_dropout if is_train else False ),
136+ )
137+ intermediate_outputs = nnx .state (model , nnx .Intermediate ).to_pure_dict ()
138+ logits = logits / config .decode_sampling_temperature
139+
140+ targets = inputs [:, 1 :]
141+ shifted_completion_segmentation = jax .lax .dynamic_slice (
142+ completion_segmentation , (0 , 1 ), (completion_segmentation .shape [0 ], completion_segmentation .shape [1 ] - 1 )
143+ )
144+ shifted_completion_segmentation = jnp .pad (
145+ shifted_completion_segmentation , ((0 , 0 ), (0 , 1 )), mode = "constant" , constant_values = 0
146+ )
147+ mask = shifted_completion_segmentation [..., None ]
148+ mask = jnp .broadcast_to (mask , logits .shape )
149+ masked_logits = jnp .where (mask , logits , - jnp .inf )
150+ log_probs = jax .nn .log_softmax (masked_logits , axis = - 1 )
151+ log_probs = jnp .where (mask , log_probs , - 0.0 )
152+ log_probs = log_probs [:, :- 1 , :]
153+ token_log_probs = jnp .take_along_axis (log_probs , targets [..., None ], axis = - 1 )[..., 0 ]
154+ token_log_probs = token_log_probs * shifted_completion_segmentation [:, :- 1 ]
155+ return token_log_probs , intermediate_outputs
156+
157+
115158def generate_offline_completions (config , tokenizer_model , inference_engine , data ):
116159 """Generates completions for a batch of prompts using an offline engine.
117160
@@ -125,6 +168,10 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data
125168 The input `data` dictionary updated with the generated completions,
126169 segmentations, positions, and log-probabilities.
127170 """
171+ # Lazy import: pulls in maxengine and jetstream stubs, which we only want to
172+ # touch when this function is actually called (i.e. during a real GRPO run).
173+ from maxtext .inference .offline_engine import InputData # pylint: disable=import-outside-toplevel
174+
128175 data [config .train_data_columns ] = np .asarray (
129176 jnp .repeat (data [config .train_data_columns ], config .num_generations , axis = 0 )
130177 )
@@ -175,6 +222,30 @@ def generate_offline_completions(config, tokenizer_model, inference_engine, data
175222 return data
176223
177224
225+ def pathways_reshard_nnx (
226+ config , inference_engine , policy_state_model , source_shardings_model , destination_shardings_model
227+ ):
228+ """`pathways_reshard` for the NNX path.
229+
230+ Reshard the policy params onto the inference mesh and push them into the
231+ inference engine. Requires `scan_layers=True` (no NNX-aware unscan helper yet).
232+ """
233+ if not config .scan_layers :
234+ raise NotImplementedError (
235+ "GRPO + pure_nnx + scan_layers=False not supported yet. " "Use scan_layers=True or pure_nnx=False."
236+ )
237+ _ , policy_params , _ = nnx .split (policy_state_model , nnx .Param , ...)
238+ _ , source_param_shardings , _ = nnx .split (source_shardings_model , nnx .Param , ...)
239+ _ , dest_param_shardings , _ = nnx .split (destination_shardings_model , nnx .Param , ...)
240+ del source_param_shardings # already encoded on policy_params
241+ with (
242+ jax .transfer_guard_device_to_host ("disallow_explicit" ),
243+ jax .transfer_guard_host_to_device ("disallow_explicit" ),
244+ ):
245+ resharded_params = reshard_pytree (policy_params , dest_param_shardings )
246+ inference_engine .update_params (resharded_params )
247+
248+
178249def pathways_reshard (config , inference_engine , params , source_shardings , source_mesh , destination_shardings ):
179250 """Reshards model parameters from training to inference sharding.
180251
0 commit comments