Skip to content

Commit f0eab7b

Browse files
Merge pull request #3803 from AI-Hypercomputer:shralex_warnings_2
PiperOrigin-RevId: 910424858
2 parents c8e277d + 3138907 commit f0eab7b

22 files changed

Lines changed: 238 additions & 227 deletions

.github/workflows/run_jupyter_notebooks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ jobs:
9494
PAPERMILL_EXE=".venv/bin/papermill"
9595
source .venv/bin/activate
9696
fi
97-
export PYTHONPATH="${pwd}/src${PYTHONPATH:+:${PYTHONPATH}}"
97+
export PYTHONPATH="${PWD}/src${PYTHONPATH:+:${PYTHONPATH}}"
9898
9999
export MAXTEXT_REPO_ROOT=$(pwd)
100100
export MAXTEXT_PKG_DIR=$(pwd)/src/maxtext

.github/workflows/run_tests_against_package.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ jobs:
138138
PYTHON_EXE=".venv/bin/python3"
139139
# Ensure pytest-cov is available and enable coverage flags
140140
uv pip install pytest-cov
141-
PYTEST_COV_ARGS="--cov=MaxText --cov=maxtext --cov-report=xml --cov-report=term"
141+
PYTEST_COV_ARGS="--cov=maxtext --cov-report=xml --cov-report=term"
142142
fi
143143
export PYTHONPATH="${PWD}/src${PYTHONPATH:+:${PYTHONPATH}}"
144144
@@ -208,7 +208,7 @@ jobs:
208208
continue-on-error: true
209209
with:
210210
token: ${{ secrets.CODECOV_TOKEN }}
211-
file: ./coverage.xml
211+
files: ./coverage.xml
212212
# If scheduled, upload to scheduled flag only. If PR, upload to regular flag only.
213213
flags: ${{ inputs.is_scheduled_run == 'true' && 'scheduled' || 'regular' }}
214214
verbose: true

src/maxtext/configs/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,7 @@ class HardwareAndMesh(BaseModel):
848848
description="Strategy for context parallelism ('all_gather' or 'ring').",
849849
)
850850
context_parallel_reorder_strategy: ReorderStrategy = Field(
851-
"auto",
851+
ReorderStrategy.AUTO,
852852
description="Reorder strategy for load-balanced context parallelism.",
853853
)
854854
custom_mesh: str = Field("", description="Available options: ['hybrid_ring_64x4', 'hybrid_ring_32x8']")

src/maxtext/inference/kvcache.py

Lines changed: 105 additions & 74 deletions
Large diffs are not rendered by default.

src/maxtext/inference/paged_attention.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -170,22 +170,22 @@ def __init__(
170170

171171
self.key_pages = nnx.Cache(
172172
jnp.zeros(self.kv_pages_shape, dtype=self.dtype),
173-
sharding=self.kv_pages_axis_names,
173+
out_sharding=self.kv_pages_axis_names,
174174
)
175175
self.value_pages = nnx.Cache(
176176
jnp.zeros(self.kv_pages_shape, dtype=self.dtype),
177-
sharding=self.kv_pages_axis_names,
177+
out_sharding=self.kv_pages_axis_names,
178178
)
179179

180180
def _maybe_materialize_cache(self, cache: nnx.Cache) -> nnx.Cache:
181181
"""Materializes the cache if it's currently a ShapeDtypeStruct."""
182-
if isinstance(cache.value, jax.ShapeDtypeStruct):
182+
if isinstance(cache.get_value(), jax.ShapeDtypeStruct):
183183
# This is needed because the Linen bridge lazily creates this state. We
184184
# need to ensure the cache state is accessible at runtime.
185185
# TODO: Delete this function when the to_linen bridge is no longer needed.
186186
return nnx.Cache(
187187
jnp.zeros(self.kv_pages_shape, dtype=self.dtype),
188-
sharding=cache.sharding,
188+
out_sharding=cache.get_metadata("out_sharding"),
189189
)
190190
return cache
191191

@@ -204,8 +204,8 @@ def get_kv_pages(self):
204204
self.key_pages = self._maybe_materialize_cache(self.key_pages)
205205
self.value_pages = self._maybe_materialize_cache(self.value_pages)
206206

207-
self.key_pages.value = nn.with_logical_constraint(self.key_pages.value, self.kv_pages_axis_names)
208-
self.value_pages.value = nn.with_logical_constraint(self.value_pages.value, self.kv_pages_axis_names)
207+
self.key_pages.set_value(nn.with_logical_constraint(self.key_pages.get_value(), self.kv_pages_axis_names))
208+
self.value_pages.set_value(nn.with_logical_constraint(self.value_pages.get_value(), self.kv_pages_axis_names))
209209
return self.key_pages, self.value_pages
210210

211211
def pad_qkv(self, *qkv):
@@ -264,9 +264,9 @@ def paged_attention_v2_prefill(
264264
is the batch_size is only 1
265265
"""
266266
assert query.shape[0] == 1 # ensure the batch size is 0
267-
# shape of key_pages_cache.value is [num_kv_heads, num_pages, tokens_per_page, head_dim]
268-
k_p = jnp.permute_dims(key_pages_cache.value, (1, 2, 0, 3))
269-
v_p = jnp.permute_dims(value_pages_cache.value, (1, 2, 0, 3))
267+
# shape of key_pages_cache.get_value() is [num_kv_heads, num_pages, tokens_per_page, head_dim]
268+
k_p = jnp.permute_dims(key_pages_cache.get_value(), (1, 2, 0, 3))
269+
v_p = jnp.permute_dims(value_pages_cache.get_value(), (1, 2, 0, 3))
270270
c_q_l = jnp.array([0, page_state.sequence_lengths[0]]) # [0, prefill_true_length]
271271
num_seqs = jnp.array([1])
272272
query = query[0] # [batch_size, max_num_tokens, num_kv_heads, head_dim] to [max_num_tokens, num_kv_heads, head_dim]
@@ -294,8 +294,8 @@ def paged_attention_v2_decode(
294294
"""Apply ragged input Paged Attention in decode only."""
295295
batch_size = query.shape[0]
296296
query = jnp.squeeze(query, axis=1) # [batch_size, seq_len, n_kv_head, head_dim] to [batch_size, n_kv_head, head_dim]
297-
k_p = jnp.permute_dims(key_pages_cache.value, (1, 2, 0, 3))
298-
v_p = jnp.permute_dims(value_pages_cache.value, (1, 2, 0, 3))
297+
k_p = jnp.permute_dims(key_pages_cache.get_value(), (1, 2, 0, 3))
298+
v_p = jnp.permute_dims(value_pages_cache.get_value(), (1, 2, 0, 3))
299299
c_q_l = jnp.arange(batch_size + 1) # one token per sequence
300300
num_seqs = jnp.array([batch_size]) # real number of requests, set it to batch_size
301301
result = paged_attention_kernel_v2.ragged_paged_attention(
@@ -352,8 +352,8 @@ def wrap_paged_attention(q, k_pages, v_pages, lengths, page_indices, pages_per_c
352352

353353
return wrap_paged_attention(
354354
query,
355-
key_pages_cache.value,
356-
value_pages_cache.value,
355+
key_pages_cache.get_value(),
356+
value_pages_cache.get_value(),
357357
page_state.sequence_lengths,
358358
page_state.page_map,
359359
self.pages_per_compute_block,
@@ -441,12 +441,12 @@ def update_prefill_step_pages(
441441
), f"prefill_step key/value should have the same shape, but getting {key.shape=} and {value.shape=} instead"
442442
batch_size, seq_len, n_kv_head, head_dim = key.shape
443443
assert seq_len % self.tokens_per_page == 0, f"seq_length {seq_len} and tokens_per_page {self.tokens_per_page}"
444-
assert key_pages_cache.value.shape == value_pages_cache.value.shape, (
444+
assert key_pages_cache.get_value().shape == value_pages_cache.get_value().shape, (
445445
f"prefill_step key/value_pages_cache should have the same shape, but "
446446
f"getting {key_pages_cache.shape=} and {value_pages_cache.shape=} instead"
447447
)
448448

449-
v_n_kv, _, v_p, v_d = key_pages_cache.value.shape
449+
v_n_kv, _, v_p, v_d = key_pages_cache.get_value().shape
450450
assert v_n_kv == n_kv_head, f"{v_n_kv=} {n_kv_head=}"
451451
assert v_p == self.tokens_per_page, f"{v_p=} {self.tokens_per_page=}"
452452
assert v_d == head_dim, f"{v_d=} {head_dim=}"
@@ -485,13 +485,13 @@ def update_prefill_step_pages(
485485
),
486486
)
487487

488-
key_pages_cache.value = nn.with_logical_constraint(key, self.kv_pages_axis_names)
489-
value_pages_cache.value = nn.with_logical_constraint(value, self.kv_pages_axis_names)
488+
key_pages_cache.set_value(nn.with_logical_constraint(key, self.kv_pages_axis_names))
489+
value_pages_cache.set_value(nn.with_logical_constraint(value, self.kv_pages_axis_names))
490490

491491
def update_decode_step_pages(self, key_pages_cache, value_pages_cache, key, value, page_state):
492492
"""Update decode-step pages"""
493-
key_pages = key_pages_cache.value
494-
value_pages = value_pages_cache.value
493+
key_pages = key_pages_cache.get_value()
494+
value_pages = value_pages_cache.get_value()
495495

496496
batch_size, _, kv_heads, head_dim = key.shape
497497
kv_heads, _, _, head_dim = key_pages.shape
@@ -511,6 +511,6 @@ def update_decode_step_pages(self, key_pages_cache, value_pages_cache, key, valu
511511
key_pages_updated = key_pages.at[kv_indices, broadcast_pages, broadcast_pos].set(new_key)
512512
value_pages_updated = value_pages.at[kv_indices, broadcast_pages, broadcast_pos].set(new_value)
513513

514-
key_pages_cache.value = key_pages_updated
515-
value_pages_cache.value = value_pages_updated
514+
key_pages_cache.set_value(key_pages_updated)
515+
value_pages_cache.set_value(value_pages_updated)
516516
return key_pages_cache, value_pages_cache

src/maxtext/layers/attention_mla.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1200,7 +1200,7 @@ def __call__(
12001200
sparse_loss=self.config.indexer_sparse_training,
12011201
scaling_factor=self.config.indexer_loss_scaling_factor,
12021202
)
1203-
self.sow(nnx.Intermediate, "indexer_loss", indexer_loss)
1203+
self.indexer_loss = nnx.Intermediate(indexer_loss)
12041204

12051205
# Check if we need QK Clip stats
12061206
use_qk_clip = self.model_mode == MODEL_MODE_TRAIN and self.config.use_qk_clip

src/maxtext/layers/attention_op.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -902,7 +902,7 @@ def apply_attention(
902902

903903
local_out, local_max, local_sum = impl(query, key, value, lengths, self.ragged_block_size)
904904
if record_max_logits:
905-
self.sow("intermediates", "max_logits", local_max)
905+
self.max_logits = nnx.Intermediate(local_max)
906906
return local_out, local_max, local_sum
907907

908908
# 'vllm_rpa' uses the same dot-attention wrapper but routes to the vLLM
@@ -951,7 +951,7 @@ def apply_attention(
951951
record_max_logits=record_max_logits,
952952
)
953953
if max_logits is not None:
954-
self.sow("intermediates", "max_logits", max_logits)
954+
self.max_logits = nnx.Intermediate(max_logits)
955955
return out, None, None
956956

957957
else:
@@ -1861,7 +1861,7 @@ def apply_attention_dot(
18611861
max_logits_per_group = jnp.max(attn_weights, axis=(-2, -1))
18621862
b, n_kv, g = max_logits_per_group.shape
18631863
max_logits = max_logits_per_group.reshape(b, n_kv * g)
1864-
self.sow("intermediates", "max_logits", max_logits)
1864+
self.max_logits = nnx.Intermediate(max_logits)
18651865

18661866
return self.compute_local_attention(attn_weights, value, q_seq_len, model_mode, wv_product_einsum, sinks)
18671867

src/maxtext/layers/embeddings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def __call__(self, inputs: Array, model_mode: str = MODEL_MODE_TRAIN) -> Array:
152152
raise ValueError("Input type must be an integer or unsigned integer.")
153153

154154
embedding = jnp.asarray(
155-
_maybe_move_embedding_to_device(self.embedding.value, self.config),
155+
_maybe_move_embedding_to_device(self.embedding.get_value(), self.config),
156156
self.dtype,
157157
)
158158

@@ -196,7 +196,7 @@ def attend(self, query: Array, out_sharding: NamedSharding | None = None) -> Arr
196196
Commonly used for weight-sharing between embeddings and logit transform
197197
in NLP models.
198198
"""
199-
embedding = self.embedding.value
199+
embedding = self.embedding.get_value()
200200
attend_dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype
201201
return attend_on_embedding(query, embedding, attend_dtype, self.config, out_sharding)
202202

src/maxtext/layers/engram.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""
1616
DeepSeek-AI, `Conditional Memory via Scalable Lookup: A New Axis of Sparsity for Large Language Models
1717
<https://arxiv.org/pdf/2601.07372>`_, 2026
18-
18+
1919
Reference implementation: https://github.com/deepseek-ai/Engram/blob/main/engram_demo_v1.py
2020
"""
2121

@@ -53,7 +53,7 @@ class CompressedTokenizer:
5353
def __init__(self, tokenizer: HFTokenizer):
5454
normalizer = self._build_normalizer()
5555
self.lookup_table_np, self.num_new_token = self._build_lookup_table(tokenizer, normalizer)
56-
self.lookup_table = jnp.array(self.lookup_table_np, dtype=jnp.int64)
56+
self.lookup_table = jnp.array(self.lookup_table_np, dtype=jnp.int32)
5757

5858
def __len__(self) -> int:
5959
return self.num_new_token
@@ -125,7 +125,7 @@ def __call__(self, input_ids) -> Array:
125125
"""
126126
Maps original token IDs to compressed IDs.
127127
"""
128-
input_ids = jnp.asarray(input_ids, dtype=jnp.int64)
128+
input_ids = jnp.asarray(input_ids, dtype=jnp.int32)
129129

130130
# Map negative IDs to 0 for lookup, then mask output back.
131131
safe_ids = jnp.where(input_ids < 0, 0, input_ids)
@@ -187,7 +187,7 @@ def __init__(
187187
# Pre-calculate odd multipliers for hashing: {layer_id: multipliers}
188188
# Store as JAX arrays
189189
self.layer_multipliers = {
190-
k: jnp.array(v, dtype=jnp.int64) for k, v in self._calculate_multipliers_across_layers(seed).items()
190+
k: jnp.array(v, dtype=jnp.int32) for k, v in self._calculate_multipliers_across_layers(seed).items()
191191
}
192192

193193
# Pre-calculate unique prime vocab sizes for every head
@@ -201,9 +201,9 @@ def _calculate_multipliers_across_layers(self, seed: int) -> dict[int, np.ndarra
201201
Returns:
202202
A dictionary mapping layer_id to a list of `max_ngram_size` multipliers.
203203
"""
204-
# Pre-calculate bounds for random generation
205-
max_long = np.iinfo(np.int64).max
206-
m_max = int(max_long // self.tokenizer_vocab_size)
204+
# Pre-calculate bounds for random generation using int32 to avoid overflow
205+
max_int = np.iinfo(np.int32).max
206+
m_max = int(max_int // self.tokenizer_vocab_size)
207207
half_bound = max(1, m_max // 2)
208208
# Hard-code prime number to align with reference
209209
LAYER_PRIME_OFFSET = 10007
@@ -214,7 +214,7 @@ def _calculate_multipliers_across_layers(self, seed: int) -> dict[int, np.ndarra
214214
layer_seed = int(seed + LAYER_PRIME_OFFSET * int(layer_id))
215215
np_rng = np.random.default_rng(layer_seed)
216216
# Generate random odd integers
217-
random_value = np_rng.integers(low=0, high=half_bound, size=(self.max_ngram_size,), dtype=np.int64)
217+
random_value = np_rng.integers(low=0, high=half_bound, size=(self.max_ngram_size,), dtype=np.int32)
218218
multipliers = random_value * 2 + 1
219219
layer_multipliers[layer_id] = multipliers
220220
return layer_multipliers
@@ -272,7 +272,7 @@ def _get_ngram_hashes(self, compressed_ids: Array, layer_id: int) -> Array:
272272
Returns:
273273
hash_ids: [B, S, H_total] where H_total = H * num_ngram_orders
274274
"""
275-
x = jnp.asarray(compressed_ids, dtype=jnp.int64)
275+
x = jnp.asarray(compressed_ids, dtype=jnp.int32)
276276
B, _ = x.shape
277277

278278
# 1. Create Sliding Windows via Shifting
@@ -282,7 +282,7 @@ def _get_ngram_hashes(self, compressed_ids: Array, layer_id: int) -> Array:
282282
shifted_inputs.append(x)
283283
else:
284284
# Pre-allocate full array with PAD_ID
285-
padding = jnp.full((B, k), self.pad_id, dtype=jnp.int64)
285+
padding = jnp.full((B, k), self.pad_id, dtype=jnp.int32)
286286
# Fast memory copy, slicing and assignment
287287
# e.g., k=1, [PAD, The, cat]
288288
# k=2, [PAD, PAD, The]
@@ -309,7 +309,7 @@ def _get_ngram_hashes(self, compressed_ids: Array, layer_id: int) -> Array:
309309

310310
# Retrieve prime vocab sizes for all heads of this n-gram order
311311
vocab_sizes_for_this_gram = vocab_sizes[n - 2]
312-
mods = jnp.array(vocab_sizes_for_this_gram, dtype=jnp.int64)
312+
mods = jnp.array(vocab_sizes_for_this_gram, dtype=jnp.int32)
313313

314314
# Broadcast Modulo: Map hash to valid table indices
315315
# [B, S, 1] % [H] -> [B, S, H]

src/maxtext/layers/initializers.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@ def init_fn(key, shape, dtype, in_axis, out_axis):
6060
return init_fn
6161

6262

63-
def variable_to_logically_partitioned(variable: nnx.VariableState):
63+
def variable_to_logically_partitioned(variable: nnx.Variable):
6464
"""Wraps an NNX variable's value in `nn.LogicallyPartitioned`.
6565
66-
This function inspects the metadata of an `nnx.VariableState` object. If
66+
This function inspects the metadata of an `nnx.Variable` object. If
6767
sharding information ('out_sharding', 'sharding' or 'sharding_names') is
6868
present, it wraps the variable's value in `nn.LogicallyPartitioned` to apply
6969
the specified sharding constraints.
@@ -73,16 +73,17 @@ def variable_to_logically_partitioned(variable: nnx.VariableState):
7373
wrapping.
7474
7575
Args:
76-
variable: The `nnx.VariableState` object to process.
76+
variable: The `nnx.Variable` object to process.
7777
7878
Returns:
7979
The variable's value, potentially wrapped in `nn.LogicallyPartitioned`.
8080
"""
81-
if isinstance(variable.value, aqt_tensor.QTensor):
82-
return variable.value
81+
val = variable.get_value()
82+
if isinstance(val, aqt_tensor.QTensor):
83+
return val
8384

8485
if variable.type.__name__ == "_overwrite_with_gradient":
85-
return variable.value
86+
return val
8687

8788
metadata = variable.get_metadata()
8889
out_sharding = None
@@ -95,10 +96,10 @@ def variable_to_logically_partitioned(variable: nnx.VariableState):
9596

9697
if out_sharding is not None:
9798
return nn.LogicallyPartitioned( # type: ignore[wrong-keyword-args]
98-
variable.value,
99+
val,
99100
out_sharding, # type: ignore[arg-type]
100101
mesh=metadata.get("mesh"),
101102
rules=metadata.get("rules"),
102103
)
103104
else:
104-
return variable.value
105+
return val

0 commit comments

Comments
 (0)