Skip to content

Commit 7988f44

Browse files
Merge pull request #3619 from AI-Hypercomputer:engram_test_fix
PiperOrigin-RevId: 896922624
2 parents 4efab25 + c2dac5c commit 7988f44

2 files changed

Lines changed: 16 additions & 7 deletions

File tree

src/maxtext/optimizers/optimizers.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
 import jax.numpy as jnp

 import optax
-from maxtext.utils import max_logging
 from optax.contrib._muon import muon
 from maxtext.utils.muon_utils import get_muon_weight_dimension_numbers

@@ -140,8 +139,7 @@ def do_update():
   return inner_opt.update(updates, state["inner_state"], params, **extra_args)

 def skip_update():
-  # use callback to work with jax.jit and jax.lax.cond for logging
-  jax.debug.callback(lambda c: max_logging.warning(f"Step {c}: Optimizer step skipped due to spike."), count)
+  # b/500923599: Investigate logging compatible with jax.jit, jax.lax.cond, and Pathway
   inner_updates = jax.tree_util.tree_map(jnp.zeros_like, updates)
   return inner_updates, state["inner_state"]

tests/unit/deepseek_scan_engram_test.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,26 @@ class TestDeepSeekScanEngram(unittest.TestCase):
   "first_num_dense_layers=5",
   "base_num_decoder_layers=10",
   "num_decoder_layers=10",
+  "base_emb_dim=64",
+  "base_mlp_dim=64",
+  "base_moe_mlp_dim=64",
+  "base_num_query_heads=2",
+  "base_num_kv_heads=2",
+  "head_dim=32",
+  "indexer_head_dim=32",
+  "qk_nope_head_dim=32",
+  "qk_rope_head_dim=16",
+  "v_head_dim=32",
+  "vocab_size=128",
   "mhc_expansion_rate=4",
   "attention=dot_product",
-  "per_device_batch_size=2",
+  "per_device_batch_size=1",
   "max_target_length=8",
   "max_prefill_predict_length=8",
   "enable_checkpointing=False",
   "engram_num_heads=1",
-  "engram_head_dim=32",
-  "engram_vocab_bases=[226240,226240]",
+  "engram_head_dim=8",
+  "engram_vocab_bases=[128,128]",
   "engram_max_ngram_size=3",
   "engram_kernel_size=4",
   "hf_access_token=dummy",
@@ -78,7 +89,7 @@ class MockTokenizer:
   pad_token_id = 0

   def __len__(self):
-    return 1000
+    return 128

   def __call__(self, x):
     return jnp.ones_like(x)

0 commit comments

Comments (0)