Skip to content

Commit 9071a8b

Browse files
Merge pull request #3860 from AI-Hypercomputer:shralex_test_2
PiperOrigin-RevId: 914947903
2 parents a01c414 + 1425f5b commit 9071a8b

26 files changed

Lines changed: 387 additions & 158 deletions

src/maxtext/utils/maxtext_utils.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1934,8 +1934,19 @@ def maybe_dump_jaxpr(config, p_train_step, train_step_inputs):
19341934
return
19351935
max_logging.log("Tracing train_step to jaxpr...")
19361936

1937-
# We use the p_train_step (the JIT-decorated function)
1938-
p_train_jaxpr = jax.make_jaxpr(p_train_step)(*train_step_inputs)
1937+
# Trace the underlying un-jitted function via __wrapped__ to avoid heavy remote
1938+
# compilation/gRPC round-trips to the Pathways controller.
1939+
unwrapped_step = getattr(p_train_step, "__wrapped__", p_train_step)
1940+
1941+
def to_abstract(x):
1942+
if hasattr(x, "shape") and hasattr(x, "dtype"):
1943+
return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
1944+
return x
1945+
1946+
# Convert all input arguments recursively to purely local abstract ShapeDtypeStruct objects
1947+
# to completely bypass remote Array objects and proxy tracing overhead.
1948+
abstract_inputs = jax.tree.map(to_abstract, train_step_inputs)
1949+
p_train_jaxpr = jax.make_jaxpr(unwrapped_step)(*abstract_inputs)
19391950

19401951
local_filename = "train_step.jaxpr"
19411952
local_path = os.path.join(config.dump_jaxpr_local_dir, local_filename)

tests/integration/aot_identical_test.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,13 @@ def assert_compile_and_real_match_hlo(self, test_name, *extra_args):
110110
"steps=1",
111111
"enable_checkpointing=False",
112112
"base_num_decoder_layers=1",
113-
"max_target_length=512",
114-
"base_emb_dim=256",
115-
"base_mlp_dim=256",
113+
"max_target_length=32",
114+
"base_emb_dim=64",
115+
"base_mlp_dim=64",
116+
"base_num_query_heads=2",
117+
"base_num_kv_heads=2",
118+
"head_dim=16",
119+
"vocab_size=128",
116120
] + hlo_dump_args
117121
if extra_args:
118122
shared_args.extend(extra_args)
@@ -179,6 +183,14 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
179183
"enable_checkpointing=False",
180184
"dump_jaxpr=True",
181185
"dump_jaxpr_delete_local_after=False",
186+
"base_num_decoder_layers=1",
187+
"max_target_length=32",
188+
"base_emb_dim=64",
189+
"base_mlp_dim=64",
190+
"base_num_query_heads=2",
191+
"base_num_kv_heads=2",
192+
"head_dim=16",
193+
"vocab_size=128",
182194
]
183195
if extra_args:
184196
shared_args.extend(extra_args)
@@ -218,5 +230,15 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
218230
)
219231

220232
@pytest.mark.tpu_only
221-
def test_default_jaxpr_match(self):
222-
self.assert_compile_and_real_match_jaxpr("default_run")
233+
def test_default_jaxpr_match_mcjax(self):
234+
if os.getenv("JAX_PLATFORMS") == "proxy":
235+
pytest.skip("This is a McJAX test, skipping in Pathways environment.")
236+
self.assert_compile_and_real_match_jaxpr("default_run_mcjax")
237+
238+
@pytest.mark.tpu_only
239+
@pytest.mark.scheduled_only
240+
def test_default_jaxpr_match_pathways(self):
241+
# Currently this test is extremely slow (b/512065615).
242+
if os.getenv("JAX_PLATFORMS") != "proxy":
243+
pytest.skip("This is a Pathways test, skipping in McJAX environment.")
244+
self.assert_compile_and_real_match_jaxpr("default_run_pathways", "enable_single_controller=True")

tests/integration/checkpoint_resharding_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@
3535
def get_resharding_command(run_date, steps, metrics_file, base_output_directory, dataset_path, parallelism_args):
3636
"""Generates a command list for the resharding test run."""
3737
model_params = [
38-
"base_emb_dim=384",
39-
"base_num_query_heads=8",
40-
"base_num_kv_heads=8",
41-
"base_mlp_dim=192",
42-
"base_num_decoder_layers=8",
43-
"head_dim=128",
38+
"base_emb_dim=128",
39+
"base_num_query_heads=2",
40+
"base_num_kv_heads=2",
41+
"base_mlp_dim=128",
42+
"base_num_decoder_layers=2",
43+
"head_dim=64",
4444
]
4545

4646
return (

tests/integration/checkpointing_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,12 @@ def get_checkpointing_command(run_date, hardware, steps, metrics_file, attention
5959
"""
6060
base_output_directory = get_test_base_output_directory()
6161
model_params = [
62-
"base_emb_dim=384",
63-
"base_num_query_heads=8",
64-
"base_num_kv_heads=8",
65-
"base_mlp_dim=192",
66-
"base_num_decoder_layers=8",
67-
"head_dim=128",
62+
"base_emb_dim=128",
63+
"base_num_query_heads=2",
64+
"base_num_kv_heads=2",
65+
"base_mlp_dim=128",
66+
"base_num_decoder_layers=1",
67+
"head_dim=64",
6868
]
6969
pathways_command = []
7070
if os.getenv("JAX_PLATFORMS") == "proxy":

tests/unit/deepseek_scan_engram_test.py renamed to tests/integration/deepseek_scan_engram_test.py

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __call__(self, x, model_mode):
4141
return jnp.ones((x.shape[0], x.shape[1], self.emb_dim))
4242

4343

44+
@pytest.mark.integration_test
4445
class TestDeepSeekScanEngram(unittest.TestCase):
4546
"""Test DeepSeek decoder block with scan_layers=True and engram_layers."""
4647

@@ -53,16 +54,16 @@ class TestDeepSeekScanEngram(unittest.TestCase):
5354
"first_num_dense_layers=5",
5455
"base_num_decoder_layers=10",
5556
"num_decoder_layers=10",
56-
"base_emb_dim=64",
57-
"base_mlp_dim=64",
58-
"base_moe_mlp_dim=64",
59-
"base_num_query_heads=2",
60-
"base_num_kv_heads=2",
61-
"head_dim=32",
62-
"indexer_head_dim=32",
63-
"qk_nope_head_dim=32",
64-
"qk_rope_head_dim=16",
65-
"v_head_dim=32",
57+
"base_emb_dim=8",
58+
"base_mlp_dim=8",
59+
"base_moe_mlp_dim=8",
60+
"base_num_query_heads=1",
61+
"base_num_kv_heads=1",
62+
"head_dim=4",
63+
"indexer_head_dim=4",
64+
"qk_nope_head_dim=4",
65+
"qk_rope_head_dim=4",
66+
"v_head_dim=4",
6667
"vocab_size=128",
6768
"mhc_expansion_rate=4",
6869
"attention=dot_product",
@@ -71,15 +72,24 @@ class TestDeepSeekScanEngram(unittest.TestCase):
7172
"max_prefill_predict_length=8",
7273
"enable_checkpointing=False",
7374
"engram_num_heads=1",
74-
"engram_head_dim=8",
75+
"engram_head_dim=4",
7576
"engram_vocab_bases=[128,128]",
7677
"engram_max_ngram_size=3",
7778
"engram_kernel_size=4",
79+
"num_experts=2",
80+
"num_experts_per_tok=1",
7881
"hf_access_token=dummy",
7982
"tokenizer_path=dummy",
8083
]
8184

82-
def _test_engram_pattern(self, mock_from_pretrained, engram_layers_str, expected_keys):
85+
def _test_engram_pattern(
86+
self,
87+
mock_from_pretrained,
88+
engram_layers_str,
89+
expected_keys,
90+
first_num_dense_layers=5,
91+
base_num_decoder_layers=10,
92+
):
8393
"""Helper method to test different engram layer patterns."""
8494

8595
# Setup mock tokenizer
@@ -106,7 +116,16 @@ def batch_decode(self, token_ids, *args, **kwargs):
106116
mock_from_pretrained.return_value = MockTokenizer()
107117

108118
config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")
109-
config = pyconfig.initialize([None, config_path] + self._COMMON_CONFIG + [f"engram_layers=[{engram_layers_str}]"])
119+
config = pyconfig.initialize(
120+
[None, config_path]
121+
+ self._COMMON_CONFIG
122+
+ [
123+
f"engram_layers=[{engram_layers_str}]",
124+
f"first_num_dense_layers={first_num_dense_layers}",
125+
f"base_num_decoder_layers={base_num_decoder_layers}",
126+
f"num_decoder_layers={base_num_decoder_layers}",
127+
]
128+
)
110129

111130
devices_array = maxtext_utils.create_device_mesh(config)
112131
mesh = Mesh(devices_array, config.mesh_axes)
@@ -126,7 +145,7 @@ def batch_decode(self, token_ids, *args, **kwargs):
126145

127146
shared_embedding = DummyEmbedding(emb_dim=config.emb_dim)
128147

129-
with mesh:
148+
with mesh, jax.disable_jit():
130149
variables = decoder.init(
131150
{"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1), "aqt": jax.random.PRNGKey(2)},
132151
shared_embedding=shared_embedding,
@@ -154,15 +173,16 @@ def test_decoder_init_engram_2_8(self, mock_from_pretrained):
154173
"""Test engram layers at indices 2 and 8."""
155174
self._test_engram_pattern(
156175
mock_from_pretrained,
157-
"2,8",
176+
"1,4",
158177
[
159-
"dense_layers_0_1",
160-
"dense_layers_engram_2",
161-
"dense_layers_3_4",
162-
"moe_layers_5_7",
163-
"moe_layers_engram_8",
164-
"moe_layers_9_9",
178+
"dense_layers_0_0",
179+
"dense_layers_engram_1",
180+
"dense_layers_2_2",
181+
"moe_layers_3_3",
182+
"moe_layers_engram_4",
165183
],
184+
first_num_dense_layers=3,
185+
base_num_decoder_layers=5,
166186
)
167187

168188
@pytest.mark.tpu_only
@@ -171,8 +191,10 @@ def test_decoder_init_engram_0_5(self, mock_from_pretrained):
171191
"""Test engram layers at indices 0 and 5 - first engram layer of block."""
172192
self._test_engram_pattern(
173193
mock_from_pretrained,
174-
"0,5",
175-
["dense_layers_engram_0", "dense_layers_1_4", "moe_layers_engram_5", "moe_layers_6_9"],
194+
"0,1",
195+
["dense_layers_engram_0", "moe_layers_engram_1"],
196+
first_num_dense_layers=1,
197+
base_num_decoder_layers=2,
176198
)
177199

178200
@pytest.mark.tpu_only
@@ -181,6 +203,8 @@ def test_decoder_init_engram_4_9(self, mock_from_pretrained):
181203
"""Test engram layers at indices 4 and 9 - last engram layer of block."""
182204
self._test_engram_pattern(
183205
mock_from_pretrained,
184-
"4,9",
185-
["dense_layers_0_3", "dense_layers_engram_4", "moe_layers_5_8", "moe_layers_engram_9"],
206+
"1,3",
207+
["dense_layers_0_0", "dense_layers_engram_1", "moe_layers_2_2", "moe_layers_engram_3"],
208+
first_num_dense_layers=2,
209+
base_num_decoder_layers=4,
186210
)

tests/integration/determinism_test.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import pytest
2727

2828
from maxtext.trainers.pre_train.train import main as train_main
29-
from tests.utils.test_helpers import get_test_config_path
29+
from tests.utils.test_helpers import get_test_config_path, get_test_base_output_directory, get_test_dataset_path
3030

3131
pytestmark = pytest.mark.integration_test
3232

@@ -52,13 +52,17 @@ def test_determinism(self):
5252
common_config = [
5353
None,
5454
get_test_config_path(),
55-
"steps=5",
55+
"steps=2",
5656
"enable_checkpointing=False",
5757
"enable_data_shuffling=True",
5858
"enable_dropout=False",
59-
"base_output_directory=gs://runner-maxtext-logs",
60-
"dataset_path=gs://maxtext-dataset",
59+
f"base_output_directory={get_test_base_output_directory()}",
60+
f"dataset_path={get_test_dataset_path()}",
6161
"skip_jax_distributed_system=True",
62+
"base_emb_dim=128",
63+
"base_mlp_dim=128",
64+
"base_num_decoder_layers=1",
65+
"head_dim=64",
6266
]
6367
train_1_config = common_config + [
6468
f"run_name={run_name}_1",
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __call__(self, x):
5151
return self.dense(x)
5252

5353

54+
@pytest.mark.integration_test
5455
class DiLoCoTest(unittest.TestCase):
5556

5657
@pytest.mark.tpu_only
@@ -284,6 +285,13 @@ def test_diloco_qwen3_moe_two_slices(self):
284285
"dcn_diloco_parallelism=2",
285286
"enable_diloco=true",
286287
"model_name=qwen3-30b-a3b",
288+
"override_model_config=True",
289+
"base_emb_dim=32",
290+
"base_num_decoder_layers=1",
291+
"base_mlp_dim=64",
292+
"base_num_query_heads=4",
293+
"base_num_kv_heads=4",
294+
"head_dim=8",
287295
)
288296
)
289297

@@ -302,5 +310,12 @@ def test_diloco_two_slices(self):
302310
"dcn_diloco_parallelism=2",
303311
"enable_diloco=true",
304312
"model_name=gemma2-2b",
313+
"override_model_config=True",
314+
"base_emb_dim=32",
315+
"base_num_decoder_layers=1",
316+
"base_mlp_dim=64",
317+
"base_num_query_heads=1",
318+
"base_num_kv_heads=1",
319+
"head_dim=4",
305320
)
306321
)

tests/integration/generate_param_only_checkpoint_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@
3131
def get_model_params(quantization):
3232
return [
3333
f"quantization={quantization}",
34-
"base_emb_dim=384",
35-
"base_num_query_heads=8",
36-
"base_num_kv_heads=8",
37-
"base_mlp_dim=192",
38-
"base_num_decoder_layers=8",
39-
"head_dim=128",
34+
"base_emb_dim=128",
35+
"base_num_query_heads=2",
36+
"base_num_kv_heads=2",
37+
"base_mlp_dim=128",
38+
"base_num_decoder_layers=1",
39+
"head_dim=64",
4040
]
4141

4242

@@ -69,7 +69,7 @@ def run_e2e_test_flow(hardware, model_config, attention_type="autoselected", sta
6969
steps=1,
7070
metrics_file="run_metrics.txt",
7171
attention_type=attention_type,
72-
dataset_type="tfds",
72+
dataset_type="synthetic",
7373
dataset_path=dataset_path,
7474
)
7575
)

tests/integration/gradient_accumulation_test.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,15 @@ def test_sft_grad_accumulate_same_loss(self):
154154
[
155155
None,
156156
get_test_config_path(),
157-
"base_output_directory=gs://runner-maxtext-logs",
158-
"dataset_path=gs://maxtext-dataset",
157+
f"base_output_directory={self.base_output_directory}",
158+
f"dataset_path={self.dataset_path}",
159+
"dataset_type=synthetic",
160+
"max_target_length=128",
159161
"gradient_clipping_threshold=0", # Ensures we are testing raw scales of gradients (clipping off).
160162
"enable_checkpointing=False",
161163
"enable_goodput_recording=False",
162-
"base_emb_dim=256",
163-
"base_num_decoder_layers=4",
164+
"base_emb_dim=128",
165+
"base_num_decoder_layers=1",
164166
rf"tokenizer_path={os.path.join(MAXTEXT_ASSETS_ROOT, 'tokenizers', 'tokenizer.llama2')}",
165167
"steps=3",
166168
"gradient_accumulation_steps=2",
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
pytestmark = [pytest.mark.external_serving]
3737

3838

39+
@pytest.mark.integration_test
3940
class MaxEngineTest(unittest.TestCase):
4041
"""Tests for MaxEngine."""
4142

@@ -55,7 +56,7 @@ def init_pyconfig(self, **kwargs):
5556
"base_num_decoder_layers": 2,
5657
"attention": "dot_product",
5758
"max_target_length": 16,
58-
"base_emb_dim": 256,
59+
"base_emb_dim": 32,
5960
"base_num_query_heads": 2,
6061
"base_num_kv_heads": 2,
6162
"max_prefill_predict_length": 4,

0 commit comments

Comments
 (0)