Skip to content

Commit 44276d9

Browse files
Merge pull request #4186 from AI-Hypercomputer:feat/nnx-correctness-tests
PiperOrigin-RevId: 934602970
2 parents ecdecc3 + 7784694 commit 44276d9

6 files changed

Lines changed: 334 additions & 139 deletions

File tree

src/maxtext/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,4 @@
4242

4343
Transformer = models.Transformer
4444
transformer_as_linen = models.transformer_as_linen
45+
from_config = model_creation_utils.from_config

tests/assets/logits_generation/generate_grpo_golden_logits.py

Lines changed: 89 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,20 @@
2626

2727
from datasets import load_dataset
2828
from flax import linen as nn
29+
from flax import nnx
2930
import jax
3031
import jax.numpy as jnp
3132
from jax.sharding import Mesh
3233
import jsonlines
3334
from maxtext.configs import pyconfig
3435
from maxtext.utils.globals import MAXTEXT_PKG_DIR, MAXTEXT_TEST_ASSETS_ROOT
3536
from maxtext.common.common_types import Array, MODEL_MODE_TRAIN
36-
from maxtext.experimental.rl.grpo_trainer import _merge_grpo_state, generate_completions, grpo_loss_fn
37-
from maxtext.experimental.rl.grpo_utils import compute_log_probs
37+
from maxtext.experimental.rl.grpo_trainer import _merge_grpo_state, generate_completions, grpo_loss_fn, grpo_loss_fn_nnx
38+
from maxtext.experimental.rl.grpo_utils import compute_log_probs, compute_log_probs_nnx
3839
from maxtext.inference.maxengine import maxengine
3940
from maxtext.models import models
4041
from maxtext.utils import maxtext_utils
42+
from maxtext.utils import model_creation_utils
4143
from tests.post_training.integration.grpo_trainer_correctness_test import prepare_maxtext_inputs
4244
import numpy as np
4345
import torch
@@ -46,6 +48,43 @@
4648
from trl import GRPOConfig, GRPOTrainer
4749

4850

51+
def _setup_model(config, mesh, rng):
52+
"""Builds the model, and for NNX a frozen reference clone, dispatching on pure_nnx.
53+
54+
Returns (model, reference_model, state). For NNX the model carries its own params
55+
(from_pretrained loads the checkpoint or inits) and state is None; for Linen the
56+
model is a ToLinen module with a separate decode state.
57+
"""
58+
if config.pure_nnx:
59+
model = model_creation_utils.from_pretrained(config, mesh=mesh, rng_key=rng)
60+
return model, nnx.clone(model), None
61+
model = models.transformer_as_linen(config=config, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN)
62+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, None, config, False, rng)
63+
state, state_mesh_annotations = maxtext_utils.setup_decode_state(config, mesh, None, init_state_fn)
64+
return model, None, (state, state_mesh_annotations)
65+
66+
67+
def _logps(config, model, state, ids, pos, seg, comp_seg):
68+
"""Policy per-token log-probs, dispatching between NNX and Linen."""
69+
if config.pure_nnx:
70+
return compute_log_probs_nnx(model, ids, pos, seg, comp_seg, config, is_train=False)
71+
return compute_log_probs(model, state.params, ids, pos, seg, comp_seg, config, is_train=False)
72+
73+
74+
def _reference_logps(config, model, reference_model, reference_params, ids, pos, seg, comp_seg):
75+
"""Reference per-token log-probs. NNX uses the cloned reference model; Linen uses the saved params."""
76+
if config.pure_nnx:
77+
return compute_log_probs_nnx(reference_model, ids, pos, seg, comp_seg, config, is_train=False)
78+
return compute_log_probs(model, {"params": reference_params}, ids, pos, seg, comp_seg, config, is_train=False)
79+
80+
81+
def _grpo_loss(config, model, reference_model, state, reference_params, data, rng):
82+
"""GRPO loss, dispatching between NNX (reference model) and Linen (reference params)."""
83+
if config.pure_nnx:
84+
return grpo_loss_fn_nnx(model, config, data, rng, None, reference_model)
85+
return grpo_loss_fn(model, config, data, rng, state.params, reference_params)
86+
87+
4988
class GRPOTest(unittest.TestCase):
5089

5190
def setUp(self):
@@ -72,28 +111,21 @@ def setUp(self):
72111
self.rng = jax.random.key(self.cfg.init_weights_seed)
73112
devices_array = maxtext_utils.create_device_mesh(self.cfg)
74113
mesh = Mesh(devices_array, self.cfg.mesh_axes)
114+
self.mesh = mesh
75115
# With checkpoint
116+
self.model, self.reference_model, linen_state = _setup_model(self.cfg, mesh, self.rng)
76117
if self.cfg.pure_nnx:
77-
# NNX has a different function to init the training state.
78-
raise NotImplementedError("Pure NNX support has not been implemented yet.")
118+
self.state = None
119+
self.state_mesh_shardings = None # NNX param shardings are derived in the generation step.
79120
else:
80-
self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN)
81-
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.cfg, False, self.rng)
82-
self.state, state_mesh_annotations = maxtext_utils.setup_decode_state(self.cfg, mesh, None, init_state_fn)
83-
self.state_mesh_shardings = nn.logical_to_mesh_sharding(state_mesh_annotations, mesh, self.cfg.logical_axis_rules)
121+
self.state, state_mesh_annotations = linen_state
122+
self.state_mesh_shardings = nn.logical_to_mesh_sharding(state_mesh_annotations, mesh, self.cfg.logical_axis_rules)
84123
self.data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec(None))
85124
# Without checkpoint
86-
if self.cfg_no_ckpt_loading.pure_nnx:
87-
# NNX has a different function to init the training state.
88-
raise NotImplementedError("Pure NNX support has not been implemented yet.")
89-
else:
90-
self.model_no_ckpt_loading = models.transformer_as_linen(
91-
config=self.cfg_no_ckpt_loading, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN
92-
)
93-
init_state_fn = functools.partial(
94-
maxtext_utils.init_initial_state, self.model_no_ckpt_loading, None, self.cfg_no_ckpt_loading, False, self.rng
95-
)
96-
self.state_no_ckpt_loading, _ = maxtext_utils.setup_decode_state(self.cfg_no_ckpt_loading, mesh, None, init_state_fn)
125+
self.model_no_ckpt_loading, self.reference_model_no_ckpt_loading, linen_state_no_ckpt = _setup_model(
126+
self.cfg_no_ckpt_loading, mesh, self.rng
127+
)
128+
self.state_no_ckpt_loading = None if self.cfg_no_ckpt_loading.pure_nnx else linen_state_no_ckpt[0]
97129

98130
self.tokenizer_model = transformers.AutoTokenizer.from_pretrained(
99131
"meta-llama/Llama-3.1-8B",
@@ -181,55 +213,48 @@ def test_w_trl_and_write_golden_data(self):
181213
input_ids, input_segmentation, input_position, completion_segmentation = prepare_maxtext_inputs(
182214
self.cfg.prompt, self.tokenizer_model
183215
)
184-
maxtext_per_token_logps, _ = compute_log_probs(
185-
self.model,
186-
self.state.params,
187-
input_ids,
188-
input_position,
189-
input_segmentation,
190-
completion_segmentation,
191-
self.cfg,
192-
is_train=False,
216+
maxtext_per_token_logps, _ = _logps(
217+
self.cfg, self.model, self.state, input_ids, input_position, input_segmentation, completion_segmentation
193218
)
194219

195-
reference_params = jax.tree.map(jnp.copy, self.state.params["params"])
196-
self.state = _merge_grpo_state(self.state, reference_params)
197-
198-
reference_params_no_ckpt_loading = jax.tree.map(jnp.copy, self.state_no_ckpt_loading.params["params"])
199-
self.state_no_ckpt_loading = _merge_grpo_state(self.state_no_ckpt_loading, reference_params_no_ckpt_loading)
220+
# The reference is a frozen copy of the step-0 policy. NNX holds it as a cloned
221+
# model (built in setUp); Linen snapshots the params and merges them into the state.
222+
reference_params = None
223+
reference_params_no_ckpt_loading = None
224+
if not self.cfg.pure_nnx:
225+
reference_params = jax.tree.map(jnp.copy, self.state.params["params"])
226+
self.state = _merge_grpo_state(self.state, reference_params)
227+
if not self.cfg_no_ckpt_loading.pure_nnx:
228+
reference_params_no_ckpt_loading = jax.tree.map(jnp.copy, self.state_no_ckpt_loading.params["params"])
229+
self.state_no_ckpt_loading = _merge_grpo_state(self.state_no_ckpt_loading, reference_params_no_ckpt_loading)
200230

201231
data = {
202232
"prompt_completions": input_ids,
203233
"prompt_completions_position": input_position,
204234
"prompt_completions_segmentation": input_segmentation,
205235
"ar_completions_segmentation": completion_segmentation,
206236
}
207-
maxtext_loss, aux = grpo_loss_fn(self.model, self.cfg, data, self.rng, self.state.params, reference_params)
237+
maxtext_loss, aux = _grpo_loss(
238+
self.cfg, self.model, self.reference_model, self.state, reference_params, data, self.rng
239+
)
208240
# pylint: disable=protected-access
209241
self.assertEqual(self.trainer._metrics["train"]["kl"][0], aux.avg_kl.tolist())
210242
self.assertEqual(hf_loss.item(), maxtext_loss.tolist())
211243
# since this is on-policy
212244
self.assertEqual(aux.avg_advantage.tolist(), 0.0)
213245
# since we are at step 0
214-
maxtext_per_token_logps, _ = compute_log_probs(
215-
self.model,
216-
self.state.params,
217-
input_ids,
218-
input_position,
219-
input_segmentation,
220-
completion_segmentation,
221-
self.cfg,
222-
is_train=False,
246+
maxtext_per_token_logps, _ = _logps(
247+
self.cfg, self.model, self.state, input_ids, input_position, input_segmentation, completion_segmentation
223248
)
224-
maxtext_per_token_logps_ref, _ = compute_log_probs(
249+
maxtext_per_token_logps_ref, _ = _reference_logps(
250+
self.cfg,
225251
self.model,
226-
{"params": reference_params},
252+
self.reference_model,
253+
reference_params,
227254
input_ids,
228255
input_position,
229256
input_segmentation,
230257
completion_segmentation,
231-
self.cfg,
232-
is_train=False,
233258
)
234259
self.assertTrue(
235260
jax.numpy.allclose(
@@ -243,25 +268,24 @@ def test_w_trl_and_write_golden_data(self):
243268
# Now that we have ensured that the MaxText implementation is correct
244269
# let us create a MaxText model without the checkpoint and save the logits
245270

246-
maxtext_per_token_logps_no_ckpt_loading, _ = compute_log_probs(
271+
maxtext_per_token_logps_no_ckpt_loading, _ = _logps(
272+
self.cfg_no_ckpt_loading,
247273
self.model_no_ckpt_loading,
248-
self.state_no_ckpt_loading.params,
274+
self.state_no_ckpt_loading,
249275
input_ids,
250276
input_position,
251277
input_segmentation,
252278
completion_segmentation,
253-
self.cfg_no_ckpt_loading,
254-
is_train=False,
255-
rngs=self.rng,
256279
)
257280

258-
maxtext_loss, aux = grpo_loss_fn(
259-
self.model_no_ckpt_loading,
281+
maxtext_loss, aux = _grpo_loss(
260282
self.cfg_no_ckpt_loading,
283+
self.model_no_ckpt_loading,
284+
self.reference_model_no_ckpt_loading,
285+
self.state_no_ckpt_loading,
286+
reference_params_no_ckpt_loading,
261287
data,
262288
self.rng,
263-
self.state_no_ckpt_loading.params,
264-
reference_params_no_ckpt_loading,
265289
)
266290

267291
engine = maxengine.MaxEngine(self.cfg_no_ckpt_loading_inference)
@@ -274,14 +298,21 @@ def test_w_trl_and_write_golden_data(self):
274298
)
275299
prompt_true_length = jnp.array([len(prompt_tokens)] * 4)
276300
engine_data = {"prompt": prompt, "prompt_true_length": prompt_true_length}
301+
if self.cfg_no_ckpt_loading.pure_nnx:
302+
# NNX params live on the model; the inference engine is NNX-aware (config.pure_nnx).
303+
gen_params = nnx.state(self.model_no_ckpt_loading, nnx.Param)
304+
gen_param_shardings = jax.tree.map(lambda _: jax.NamedSharding(self.mesh, jax.sharding.PartitionSpec()), gen_params)
305+
else:
306+
gen_params = {"params": self.state_no_ckpt_loading.params["params"]}
307+
gen_param_shardings = self.state_mesh_shardings.params
277308
p_generate_completions: Callable[[dict, dict, Array], Array] = jax.jit(
278309
functools.partial(generate_completions, self.cfg, self.tokenizer_model, engine),
279-
in_shardings=(self.data_sharding, self.state_mesh_shardings.params, None),
310+
in_shardings=(self.data_sharding, gen_param_shardings, None),
280311
out_shardings=self.data_sharding,
281312
donate_argnums=(0,),
282313
)
283314
# pylint: disable=not-callable
284-
engine_data = p_generate_completions(engine_data, {"params": self.state_no_ckpt_loading.params["params"]}, self.rng)
315+
engine_data = p_generate_completions(engine_data, gen_params, self.rng)
285316
data_to_save = {
286317
"maxtext_loss": maxtext_loss.tolist(),
287318
"input_ids": input_ids[0].tolist(),

0 commit comments

Comments
 (0)