2626
2727from datasets import load_dataset
2828from flax import linen as nn
29+ from flax import nnx
2930import jax
3031import jax .numpy as jnp
3132from jax .sharding import Mesh
3233import jsonlines
3334from maxtext .configs import pyconfig
3435from maxtext .utils .globals import MAXTEXT_PKG_DIR , MAXTEXT_TEST_ASSETS_ROOT
3536from maxtext .common .common_types import Array , MODEL_MODE_TRAIN
36- from maxtext .experimental .rl .grpo_trainer import _merge_grpo_state , generate_completions , grpo_loss_fn
37- from maxtext .experimental .rl .grpo_utils import compute_log_probs
37+ from maxtext .experimental .rl .grpo_trainer import _merge_grpo_state , generate_completions , grpo_loss_fn , grpo_loss_fn_nnx
38+ from maxtext .experimental .rl .grpo_utils import compute_log_probs , compute_log_probs_nnx
3839from maxtext .inference .maxengine import maxengine
3940from maxtext .models import models
4041from maxtext .utils import maxtext_utils
42+ from maxtext .utils import model_creation_utils
4143from tests .post_training .integration .grpo_trainer_correctness_test import prepare_maxtext_inputs
4244import numpy as np
4345import torch
4648from trl import GRPOConfig , GRPOTrainer
4749
4850
def _setup_model(config, mesh, rng):
  """Create the model under test, choosing the NNX or Linen path from config.pure_nnx.

  Args:
    config: Config object; only `pure_nnx` is read here, the rest is forwarded.
    mesh: Device mesh forwarded to the model/state builders.
    rng: PRNG key forwarded to the model/state builders.

  Returns:
    A 3-tuple (model, reference_model, state):
      * Linen: the module from models.transformer_as_linen, None, and the pair
        (state, state_mesh_annotations) produced by maxtext_utils.setup_decode_state.
      * NNX: the model from model_creation_utils.from_pretrained, an nnx.clone
        of that model to serve as the reference, and None.
  """
  if not config.pure_nnx:
    linen_model = models.transformer_as_linen(config=config, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN)
    build_initial_state = functools.partial(
        maxtext_utils.init_initial_state, linen_model, None, config, False, rng
    )
    decode_state, mesh_annotations = maxtext_utils.setup_decode_state(config, mesh, None, build_initial_state)
    return linen_model, None, (decode_state, mesh_annotations)

  # NNX path: no separate state object is returned; the reference is a clone of the model.
  nnx_model = model_creation_utils.from_pretrained(config, mesh=mesh, rng_key=rng)
  reference_copy = nnx.clone(nnx_model)
  return nnx_model, reference_copy, None
65+
66+
def _logps(config, model, state, ids, pos, seg, comp_seg):
  """Compute policy per-token log-probs via the NNX or Linen helper.

  When config.pure_nnx is set, `state` is not touched and compute_log_probs_nnx
  is called on `model` directly. Otherwise compute_log_probs is called with
  `state.params`. Both calls run with is_train=False, and the helper's return
  value is passed back unchanged.
  """
  token_inputs = (ids, pos, seg, comp_seg)
  if not config.pure_nnx:
    return compute_log_probs(model, state.params, *token_inputs, config, is_train=False)
  return compute_log_probs_nnx(model, *token_inputs, config, is_train=False)
72+
73+
def _reference_logps(config, model, reference_model, reference_params, ids, pos, seg, comp_seg):
  """Compute reference per-token log-probs via the NNX or Linen helper.

  NNX (config.pure_nnx): evaluates `reference_model`; `model` and
  `reference_params` are unused. Linen: evaluates `model` with
  `reference_params` wrapped under a top-level "params" key; `reference_model`
  is unused. Both calls run with is_train=False.
  """
  token_inputs = (ids, pos, seg, comp_seg)
  if not config.pure_nnx:
    wrapped_reference = {"params": reference_params}
    return compute_log_probs(model, wrapped_reference, *token_inputs, config, is_train=False)
  return compute_log_probs_nnx(reference_model, *token_inputs, config, is_train=False)
79+
80+
def _grpo_loss(config, model, reference_model, state, reference_params, data, rng):
  """Evaluate the GRPO loss through the NNX or Linen loss function.

  Linen: grpo_loss_fn receives `state.params` and `reference_params`.
  NNX (config.pure_nnx): grpo_loss_fn_nnx receives None in the fifth positional
  slot followed by `reference_model`; `state` and `reference_params` are unused.
  The loss function's return value is passed back unchanged.
  """
  shared_args = (model, config, data, rng)
  if not config.pure_nnx:
    return grpo_loss_fn(*shared_args, state.params, reference_params)
  return grpo_loss_fn_nnx(*shared_args, None, reference_model)
86+
87+
4988class GRPOTest (unittest .TestCase ):
5089
5190 def setUp (self ):
@@ -72,28 +111,21 @@ def setUp(self):
72111 self .rng = jax .random .key (self .cfg .init_weights_seed )
73112 devices_array = maxtext_utils .create_device_mesh (self .cfg )
74113 mesh = Mesh (devices_array , self .cfg .mesh_axes )
114+ self .mesh = mesh
75115 # With checkpoint
116+ self .model , self .reference_model , linen_state = _setup_model (self .cfg , mesh , self .rng )
76117 if self .cfg .pure_nnx :
77- # NNX has a different function to init the training state.
78- raise NotImplementedError ( "Pure NNX support has not been implemented yet." )
118+ self . state = None
119+ self . state_mesh_shardings = None # NNX param shardings are derived in the generation step.
79120 else :
80- self .model = models .transformer_as_linen (config = self .cfg , mesh = mesh , quant = None , model_mode = MODEL_MODE_TRAIN )
81- init_state_fn = functools .partial (maxtext_utils .init_initial_state , self .model , None , self .cfg , False , self .rng )
82- self .state , state_mesh_annotations = maxtext_utils .setup_decode_state (self .cfg , mesh , None , init_state_fn )
83- self .state_mesh_shardings = nn .logical_to_mesh_sharding (state_mesh_annotations , mesh , self .cfg .logical_axis_rules )
121+ self .state , state_mesh_annotations = linen_state
122+ self .state_mesh_shardings = nn .logical_to_mesh_sharding (state_mesh_annotations , mesh , self .cfg .logical_axis_rules )
84123 self .data_sharding = jax .NamedSharding (mesh , jax .sharding .PartitionSpec (None ))
85124 # Without checkpoint
86- if self .cfg_no_ckpt_loading .pure_nnx :
87- # NNX has a different function to init the training state.
88- raise NotImplementedError ("Pure NNX support has not been implemented yet." )
89- else :
90- self .model_no_ckpt_loading = models .transformer_as_linen (
91- config = self .cfg_no_ckpt_loading , mesh = mesh , quant = None , model_mode = MODEL_MODE_TRAIN
92- )
93- init_state_fn = functools .partial (
94- maxtext_utils .init_initial_state , self .model_no_ckpt_loading , None , self .cfg_no_ckpt_loading , False , self .rng
95- )
96- self .state_no_ckpt_loading , _ = maxtext_utils .setup_decode_state (self .cfg_no_ckpt_loading , mesh , None , init_state_fn )
125+ self .model_no_ckpt_loading , self .reference_model_no_ckpt_loading , linen_state_no_ckpt = _setup_model (
126+ self .cfg_no_ckpt_loading , mesh , self .rng
127+ )
128+ self .state_no_ckpt_loading = None if self .cfg_no_ckpt_loading .pure_nnx else linen_state_no_ckpt [0 ]
97129
98130 self .tokenizer_model = transformers .AutoTokenizer .from_pretrained (
99131 "meta-llama/Llama-3.1-8B" ,
@@ -181,55 +213,48 @@ def test_w_trl_and_write_golden_data(self):
181213 input_ids , input_segmentation , input_position , completion_segmentation = prepare_maxtext_inputs (
182214 self .cfg .prompt , self .tokenizer_model
183215 )
184- maxtext_per_token_logps , _ = compute_log_probs (
185- self .model ,
186- self .state .params ,
187- input_ids ,
188- input_position ,
189- input_segmentation ,
190- completion_segmentation ,
191- self .cfg ,
192- is_train = False ,
216+ maxtext_per_token_logps , _ = _logps (
217+ self .cfg , self .model , self .state , input_ids , input_position , input_segmentation , completion_segmentation
193218 )
194219
195- reference_params = jax .tree .map (jnp .copy , self .state .params ["params" ])
196- self .state = _merge_grpo_state (self .state , reference_params )
197-
198- reference_params_no_ckpt_loading = jax .tree .map (jnp .copy , self .state_no_ckpt_loading .params ["params" ])
199- self .state_no_ckpt_loading = _merge_grpo_state (self .state_no_ckpt_loading , reference_params_no_ckpt_loading )
220+ # The reference is a frozen copy of the step-0 policy. NNX holds it as a cloned
221+ # model (built in setUp); Linen snapshots the params and merges them into the state.
222+ reference_params = None
223+ reference_params_no_ckpt_loading = None
224+ if not self .cfg .pure_nnx :
225+ reference_params = jax .tree .map (jnp .copy , self .state .params ["params" ])
226+ self .state = _merge_grpo_state (self .state , reference_params )
227+ if not self .cfg_no_ckpt_loading .pure_nnx :
228+ reference_params_no_ckpt_loading = jax .tree .map (jnp .copy , self .state_no_ckpt_loading .params ["params" ])
229+ self .state_no_ckpt_loading = _merge_grpo_state (self .state_no_ckpt_loading , reference_params_no_ckpt_loading )
200230
201231 data = {
202232 "prompt_completions" : input_ids ,
203233 "prompt_completions_position" : input_position ,
204234 "prompt_completions_segmentation" : input_segmentation ,
205235 "ar_completions_segmentation" : completion_segmentation ,
206236 }
207- maxtext_loss , aux = grpo_loss_fn (self .model , self .cfg , data , self .rng , self .state .params , reference_params )
237+ maxtext_loss , aux = _grpo_loss (
238+ self .cfg , self .model , self .reference_model , self .state , reference_params , data , self .rng
239+ )
208240 # pylint: disable=protected-access
209241 self .assertEqual (self .trainer ._metrics ["train" ]["kl" ][0 ], aux .avg_kl .tolist ())
210242 self .assertEqual (hf_loss .item (), maxtext_loss .tolist ())
211243 # since this is on-policy
212244 self .assertEqual (aux .avg_advantage .tolist (), 0.0 )
213245 # since we are at step 0
214- maxtext_per_token_logps , _ = compute_log_probs (
215- self .model ,
216- self .state .params ,
217- input_ids ,
218- input_position ,
219- input_segmentation ,
220- completion_segmentation ,
221- self .cfg ,
222- is_train = False ,
246+ maxtext_per_token_logps , _ = _logps (
247+ self .cfg , self .model , self .state , input_ids , input_position , input_segmentation , completion_segmentation
223248 )
224- maxtext_per_token_logps_ref , _ = compute_log_probs (
249+ maxtext_per_token_logps_ref , _ = _reference_logps (
250+ self .cfg ,
225251 self .model ,
226- {"params" : reference_params },
252+ self .reference_model ,
253+ reference_params ,
227254 input_ids ,
228255 input_position ,
229256 input_segmentation ,
230257 completion_segmentation ,
231- self .cfg ,
232- is_train = False ,
233258 )
234259 self .assertTrue (
235260 jax .numpy .allclose (
@@ -243,25 +268,24 @@ def test_w_trl_and_write_golden_data(self):
243268 # Now that we have ensured that the MaxText implementation is correct
244269 # let us create a MaxText model without the checkpoint and save the logits
245270
246- maxtext_per_token_logps_no_ckpt_loading , _ = compute_log_probs (
271+ maxtext_per_token_logps_no_ckpt_loading , _ = _logps (
272+ self .cfg_no_ckpt_loading ,
247273 self .model_no_ckpt_loading ,
248- self .state_no_ckpt_loading . params ,
274+ self .state_no_ckpt_loading ,
249275 input_ids ,
250276 input_position ,
251277 input_segmentation ,
252278 completion_segmentation ,
253- self .cfg_no_ckpt_loading ,
254- is_train = False ,
255- rngs = self .rng ,
256279 )
257280
258- maxtext_loss , aux = grpo_loss_fn (
259- self .model_no_ckpt_loading ,
281+ maxtext_loss , aux = _grpo_loss (
260282 self .cfg_no_ckpt_loading ,
283+ self .model_no_ckpt_loading ,
284+ self .reference_model_no_ckpt_loading ,
285+ self .state_no_ckpt_loading ,
286+ reference_params_no_ckpt_loading ,
261287 data ,
262288 self .rng ,
263- self .state_no_ckpt_loading .params ,
264- reference_params_no_ckpt_loading ,
265289 )
266290
267291 engine = maxengine .MaxEngine (self .cfg_no_ckpt_loading_inference )
@@ -274,14 +298,21 @@ def test_w_trl_and_write_golden_data(self):
274298 )
275299 prompt_true_length = jnp .array ([len (prompt_tokens )] * 4 )
276300 engine_data = {"prompt" : prompt , "prompt_true_length" : prompt_true_length }
301+ if self .cfg_no_ckpt_loading .pure_nnx :
302+ # NNX params live on the model; the inference engine is NNX-aware (config.pure_nnx).
303+ gen_params = nnx .state (self .model_no_ckpt_loading , nnx .Param )
304+ gen_param_shardings = jax .tree .map (lambda _ : jax .NamedSharding (self .mesh , jax .sharding .PartitionSpec ()), gen_params )
305+ else :
306+ gen_params = {"params" : self .state_no_ckpt_loading .params ["params" ]}
307+ gen_param_shardings = self .state_mesh_shardings .params
277308 p_generate_completions : Callable [[dict , dict , Array ], Array ] = jax .jit (
278309 functools .partial (generate_completions , self .cfg , self .tokenizer_model , engine ),
279- in_shardings = (self .data_sharding , self . state_mesh_shardings . params , None ),
310+ in_shardings = (self .data_sharding , gen_param_shardings , None ),
280311 out_shardings = self .data_sharding ,
281312 donate_argnums = (0 ,),
282313 )
283314 # pylint: disable=not-callable
284- engine_data = p_generate_completions (engine_data , { "params" : self . state_no_ckpt_loading . params [ "params" ]} , self .rng )
315+ engine_data = p_generate_completions (engine_data , gen_params , self .rng )
285316 data_to_save = {
286317 "maxtext_loss" : maxtext_loss .tolist (),
287318 "input_ids" : input_ids [0 ].tolist (),
0 commit comments