Skip to content

Commit f19da9a

Browse files
committed
[NNX] Delete Linen (3/4): drop obsolete Linen tests and flag references
Remove obsolete Linen-only tests, drop redundant flag arguments from the remaining tests, and compile the hlo_diff tests via base.yml + model_name so that they exercise the real NNX path.
1 parent 2b50bf9 commit f19da9a

25 files changed

Lines changed: 208 additions & 1543 deletions

tests/assets/logits_generation/generate_grpo_golden_logits.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -73,26 +73,18 @@ def setUp(self):
7373
devices_array = maxtext_utils.create_device_mesh(self.cfg)
7474
mesh = Mesh(devices_array, self.cfg.mesh_axes)
7575
# With checkpoint
76-
if self.cfg.pure_nnx:
77-
# NNX has a different function to init the training state.
78-
raise NotImplementedError("Pure NNX support has not been implemented yet.")
79-
else:
80-
self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN)
81-
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.cfg, False, self.rng)
76+
self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN)
77+
init_state_fn = functools.partial(maxtext_utils.init_initial_state, self.model, None, self.cfg, False, self.rng)
8278
self.state, state_mesh_annotations = maxtext_utils.setup_decode_state(self.cfg, mesh, None, init_state_fn)
8379
self.state_mesh_shardings = nn.logical_to_mesh_sharding(state_mesh_annotations, mesh, self.cfg.logical_axis_rules)
8480
self.data_sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec(None))
8581
# Without checkpoint
86-
if self.cfg_no_ckpt_loading.pure_nnx:
87-
# NNX has a different function to init the training state.
88-
raise NotImplementedError("Pure NNX support has not been implemented yet.")
89-
else:
90-
self.model_no_ckpt_loading = models.transformer_as_linen(
91-
config=self.cfg_no_ckpt_loading, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN
92-
)
93-
init_state_fn = functools.partial(
94-
maxtext_utils.init_initial_state, self.model_no_ckpt_loading, None, self.cfg_no_ckpt_loading, False, self.rng
95-
)
82+
self.model_no_ckpt_loading = models.transformer_as_linen(
83+
config=self.cfg_no_ckpt_loading, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN
84+
)
85+
init_state_fn = functools.partial(
86+
maxtext_utils.init_initial_state, self.model_no_ckpt_loading, None, self.cfg_no_ckpt_loading, False, self.rng
87+
)
9688
self.state_no_ckpt_loading, _ = maxtext_utils.setup_decode_state(self.cfg_no_ckpt_loading, mesh, None, init_state_fn)
9789

9890
self.tokenizer_model = transformers.AutoTokenizer.from_pretrained(

tests/integration/deepseek_scan_engram_test.py

Lines changed: 10 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,19 @@
1414

1515
"""Unit tests for DeepSeek Engram across scanned decoder layers."""
1616

17-
import gc
18-
import os
1917
import unittest
2018
from unittest.mock import patch
2119

22-
import jax
2320
import jax.numpy as jnp
24-
from jax.sharding import Mesh
2521

26-
from maxtext.configs import pyconfig
27-
from maxtext.utils.globals import MAXTEXT_PKG_DIR
28-
from maxtext.common.common_types import MODEL_MODE_TRAIN
29-
from maxtext.layers.decoders import Decoder
30-
from maxtext.utils import maxtext_utils
3122
import pytest
3223

24+
# The Linen Decoder this test exercised was removed in PR12 (Delete Linen).
25+
# NNX decoder coverage is in tests/unit/nnx_decoders_test.py.
26+
pytestmark = pytest.mark.skip(
27+
reason="Linen Decoder removed in PR12 (Delete Linen); NNX decoder coverage is in tests/unit/nnx_decoders_test.py"
28+
)
29+
3330

3431
class DummyEmbedding:
3532
"""Dummy embedding layer for testing."""
@@ -91,81 +88,10 @@ def _test_engram_pattern(
9188
base_num_decoder_layers=10,
9289
):
9390
"""Helper method to test different engram layer patterns."""
94-
95-
# Setup mock tokenizer
96-
class MockTokenizer:
97-
"""Mock tokenizer for testing."""
98-
99-
pad_token_id = 0
100-
101-
def __len__(self):
102-
return 128
103-
104-
def __call__(self, x):
105-
return jnp.ones_like(x)
106-
107-
def convert_ids_to_tokens(self, *args, **kwargs):
108-
return "a"
109-
110-
def decode(self, *args, **kwargs):
111-
return "a"
112-
113-
def batch_decode(self, token_ids, *args, **kwargs):
114-
return ["a" for _ in token_ids]
115-
116-
mock_from_pretrained.return_value = MockTokenizer()
117-
118-
config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")
119-
config = pyconfig.initialize(
120-
[None, config_path]
121-
+ self._COMMON_CONFIG
122-
+ [
123-
f"engram_layers=[{engram_layers_str}]",
124-
f"first_num_dense_layers={first_num_dense_layers}",
125-
f"base_num_decoder_layers={base_num_decoder_layers}",
126-
f"num_decoder_layers={base_num_decoder_layers}",
127-
]
128-
)
129-
130-
devices_array = maxtext_utils.create_device_mesh(config)
131-
mesh = Mesh(devices_array, config.mesh_axes)
132-
133-
decoder = Decoder(
134-
config=config,
135-
mesh=mesh,
136-
model_mode=MODEL_MODE_TRAIN,
137-
)
138-
139-
batch_size = config.global_batch_size_to_load
140-
seq_len = config.max_target_length
141-
142-
decoder_input_tokens = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
143-
decoder_positions = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
144-
decoder_segment_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)
145-
146-
shared_embedding = DummyEmbedding(emb_dim=config.emb_dim)
147-
148-
with jax.set_mesh(mesh), jax.disable_jit():
149-
variables = decoder.init(
150-
{"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1), "aqt": jax.random.PRNGKey(2)},
151-
shared_embedding=shared_embedding,
152-
decoder_input_tokens=decoder_input_tokens,
153-
decoder_positions=decoder_positions,
154-
decoder_segment_ids=decoder_segment_ids,
155-
deterministic=True,
156-
model_mode=MODEL_MODE_TRAIN,
157-
)
158-
159-
self.assertIn("params", variables)
160-
params = variables["params"]
161-
for key in expected_keys:
162-
self.assertIn(key, params)
163-
164-
del variables
165-
del params
166-
del decoder
167-
jax.clear_caches()
168-
gc.collect()
91+
# The Linen Decoder that this helper exercised was removed in PR12 (Delete Linen);
92+
# NNX decoder coverage lives in tests/unit/nnx_decoders_test.py.
93+
del mock_from_pretrained, engram_layers_str, expected_keys, first_num_dense_layers, base_num_decoder_layers
94+
raise unittest.SkipTest("Linen Decoder removed in PR12 (Delete Linen)")
16995

17096
@pytest.mark.tpu_only
17197
@patch("transformers.AutoTokenizer.from_pretrained")

tests/integration/diloco_test.py

Lines changed: 50 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
import chex
2323
from flax.experimental import nnx
24-
from flax.training import train_state
2524
import jax
2625
import jax.numpy as jnp
2726
import jax.sharding
@@ -84,71 +83,36 @@ def test_diloco_training_simulation_with_mesh(self):
8483
tx = optax.sgd(learning_rate=0.1)
8584
rngs = nnx.Rngs(params=jax.random.key(seed=42))
8685
model = SimpleNNXModel(rngs=rngs)
87-
graphdef, params = nnx.split(model)
8886

89-
if test_config.pure_nnx:
90-
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
91-
# diloco_test_state expects a TrainStateNNX instance when pure_nnx is True.
92-
initial_test_state = TrainStateNNX(model, optimizer)
87+
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
88+
# diloco_test_state expects a TrainStateNNX instance.
89+
initial_test_state = TrainStateNNX(model, optimizer)
9390

94-
# For NNX, train_step needs to take the TrainStateNNX and mutate it
91+
# train_step takes the TrainStateNNX and mutates it.
9592

96-
def _test_train_step(state, batch, prng_key: diloco.PRNGKey):
97-
del prng_key
93+
def _test_train_step(state, batch, prng_key: diloco.PRNGKey):
94+
del prng_key
9895

99-
def loss_fn(model, batch):
100-
inputs, labels = batch
101-
logits = jax.vmap(model)(inputs)
102-
residual = logits - labels
103-
return jnp.mean(jnp.square(residual))
96+
def loss_fn(model, batch):
97+
inputs, labels = batch
98+
logits = jax.vmap(model)(inputs)
99+
residual = logits - labels
100+
return jnp.mean(jnp.square(residual))
104101

105-
loss, grads = nnx.value_and_grad(loss_fn)(state.model, batch)
106-
state.optimizer.update(state.model, grads)
107-
return state, loss
108-
109-
else:
110-
111-
def nnx_apply_fn(params, inputs):
112-
model_replica = nnx.merge(graphdef, params)
113-
return model_replica(inputs)
114-
115-
# 2. Vmap this new wrapper function
116-
vmapped_apply = jax.vmap(nnx_apply_fn, in_axes=(None, 0))
117-
118-
def _test_train_step(state: train_state.TrainState, batch, prng_key: diloco.PRNGKey):
119-
"""A simple MSE loss train step to enable numerics testing."""
120-
del prng_key
121-
122-
def loss_fn(params, batch):
123-
inputs, labels = batch
124-
logits = vmapped_apply(params, inputs)
125-
residual = logits - labels
126-
sq_residual = jnp.square(residual)
127-
msq_residual = jnp.mean(sq_residual)
128-
return msq_residual
129-
130-
loss, grad = jax.value_and_grad(loss_fn)(state.params, batch)
131-
return state.apply_gradients(grads=grad), loss
132-
133-
initial_test_state = train_state.TrainState.create(
134-
apply_fn=vmapped_apply,
135-
params=params,
136-
tx=tx,
137-
)
102+
loss, grads = nnx.value_and_grad(loss_fn)(state.model, batch)
103+
state.optimizer.update(state.model, grads)
104+
return state, loss
138105

139106
diloco_test_state, _ = diloco.build_diloco_state(test_config, lambda: initial_test_state)
140107
chex.assert_equal(diloco_test_state.step, 0)
141-
if test_config.pure_nnx:
142-
_, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...)
108+
_, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...)
143109

144-
# diloco_test_state.params might contain nnx.Variables instead of pure arrays.
145-
# We need to unwrap them if they do.
146-
diloco_params_pure = jax.tree_util.tree_map(
147-
lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params
148-
)
149-
chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict())
150-
else:
151-
chex.assert_trees_all_equal(diloco_test_state.params, initial_test_state.params)
110+
# diloco_test_state.params might contain nnx.Variables instead of pure arrays.
111+
# We need to unwrap them if they do.
112+
diloco_params_pure = jax.tree_util.tree_map(
113+
lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params
114+
)
115+
chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict())
152116

153117
diloco_train_step = diloco.build_diloco_train_step(test_config, _test_train_step)
154118
inputs = jnp.array(
@@ -196,17 +160,14 @@ def loss_fn(params, batch):
196160
chex.assert_equal(diloco_test_state.step, 1.0)
197161
chex.assert_equal(loss, 1.0)
198162
# Assert no updates to the global model yet (no synchronization)
199-
if test_config.pure_nnx:
200-
_, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...)
163+
_, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...)
201164

202-
# diloco_test_state.params might contain nnx.Variables instead of pure arrays.
203-
# We need to unwrap them if they do.
204-
diloco_params_pure = jax.tree_util.tree_map(
205-
lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params
206-
)
207-
chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict())
208-
else:
209-
chex.assert_trees_all_equal(diloco_test_state.params, initial_test_state.params)
165+
# diloco_test_state.params might contain nnx.Variables instead of pure arrays.
166+
# We need to unwrap them if they do.
167+
diloco_params_pure = jax.tree_util.tree_map(
168+
lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params
169+
)
170+
chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict())
210171

211172
# Run the second step (no synchronization).
212173
# Replica 0:
@@ -236,17 +197,14 @@ def loss_fn(params, batch):
236197
chex.assert_equal(diloco_test_state.step, 2.0)
237198
chex.assert_trees_all_close(loss, 0.65)
238199
# Assert no updates to the global model yet (no synchronization)
239-
if test_config.pure_nnx:
240-
_, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...)
200+
_, params_pure, _ = nnx.split(initial_test_state.model, nnx.Param, ...)
241201

242-
# diloco_test_state.params might contain nnx.Variables instead of pure arrays.
243-
# We need to unwrap them if they do.
244-
diloco_params_pure = jax.tree_util.tree_map(
245-
lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params
246-
)
247-
chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict())
248-
else:
249-
chex.assert_trees_all_equal(diloco_test_state.params, initial_test_state.params)
202+
# diloco_test_state.params might contain nnx.Variables instead of pure arrays.
203+
# We need to unwrap them if they do.
204+
diloco_params_pure = jax.tree_util.tree_map(
205+
lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params
206+
)
207+
chex.assert_trees_all_equal(diloco_params_pure, params_pure.to_pure_dict())
250208

251209
# Run the third step, which synchronizes afterwards.
252210
# Replica 0:
@@ -281,31 +239,21 @@ def loss_fn(params, batch):
281239
chex.assert_trees_all_close(loss, 0.4481)
282240
# Assert that inner and outer parameters are all equal now that
283241
# synchronization has happened.
284-
if test_config.pure_nnx:
285-
_, inner_params, _ = nnx.split(diloco_test_state.inner_state.model, nnx.Param, ...)
286-
inner_params_pure = jax.tree_util.tree_map(
287-
lambda x: x.value if hasattr(x, "value") else x, inner_params.to_pure_dict()
288-
)
289-
diloco_params_pure_3 = jax.tree_util.tree_map(
290-
lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params
291-
)
292-
chex.assert_trees_all_equal(
293-
diloco_params_pure_3,
294-
jax.tree.map(lambda arr: arr[0, ...], inner_params_pure),
295-
)
296-
chex.assert_trees_all_equal(
297-
diloco_params_pure_3,
298-
jax.tree.map(lambda arr: arr[1, ...], inner_params_pure),
299-
)
300-
else:
301-
chex.assert_trees_all_equal(
302-
diloco_test_state.params,
303-
jax.tree.map(lambda arr: arr[0, ...], diloco_test_state.inner_state.params),
304-
)
305-
chex.assert_trees_all_equal(
306-
diloco_test_state.params,
307-
jax.tree.map(lambda arr: arr[1, ...], diloco_test_state.inner_state.params),
308-
)
242+
_, inner_params, _ = nnx.split(diloco_test_state.inner_state.model, nnx.Param, ...)
243+
inner_params_pure = jax.tree_util.tree_map(
244+
lambda x: x.value if hasattr(x, "value") else x, inner_params.to_pure_dict()
245+
)
246+
diloco_params_pure_3 = jax.tree_util.tree_map(
247+
lambda x: x.value if hasattr(x, "value") else x, diloco_test_state.params
248+
)
249+
chex.assert_trees_all_equal(
250+
diloco_params_pure_3,
251+
jax.tree.map(lambda arr: arr[0, ...], inner_params_pure),
252+
)
253+
chex.assert_trees_all_equal(
254+
diloco_params_pure_3,
255+
jax.tree.map(lambda arr: arr[1, ...], inner_params_pure),
256+
)
309257

310258
# Run the fourth step (no synchronization).
311259
# Replica 0:

tests/integration/hlo_diff_test.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,17 @@ def test_hlo_diff(self, test_id, config_file, overrides):
138138

139139
try:
140140
base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
141-
config_path = os.path.join(base_dir, config_file)
141+
# Compile via base.yml + model_name (the normal training path) so the config inherits
142+
# base.yml's logical_axis_rules and exercises the real NNX path. Loading the model yml
143+
# directly as the top-level config skips base.yml, leaving logical_axis_rules empty.
144+
base_config_path = os.path.join(base_dir, "src/maxtext/configs/base.yml")
145+
model_name = os.path.splitext(os.path.basename(config_file))[0]
142146

143147
# Arguments for train_compile
144148
test_args = [
145149
None,
146-
config_path,
150+
base_config_path,
151+
f"model_name={model_name}",
147152
"dataset_type=synthetic",
148153
"override_model_config=true",
149154
"compile_topology_num_slices=1",

0 commit comments

Comments
 (0)