Skip to content

Commit faef777

Browse files
darisoy authored and Google-ML-Automation committed
PR #3948: Scrub Paged Attention Serving Kernels and Obsolete Allocators
Imported from GitHub PR #3948 # Description This PR completely removes and scrubs the obsolete, TPU-specific, and vulnerable custom JAX Pallas Paged Attention serving kernels and configurations from MaxText. During security reviews (such as `b/510375529`), the physical page manager allocator was flagged for potential cross-tenant HBM memory leaks. Following team alignment, it was confirmed that no production multimodal or Reinforcement Learning (RL) serving pipelines use `attention="paged"`. To ensure model architecture definitions (Gemma, Llama, Mistral, Qwen, DeepSeek) and attention layer files do not crash on missing imports during this transition, this PR introduces lightweight, zero-overhead transitional compatibility shims in place of the deleted allocator and operator. ### Details & Implementation: * **Paged Attention Deletion**: Purged custom Pallas kernels (`paged_attention_kernel_v2.py`) and page manager tests (`page_manager_test.py`). Scrubbed `pagedattn_` variables from `base.yml`, `types.py`, and config validators inside `pyconfig_deprecated.py`. Scrubbed page allocators and layout bindings inside `maxengine.py`. * **Transitional Compatibility Shims**: Created dummy shims for `src/maxtext/inference/page_manager.py` (`PageState`, `PageManager`) and `src/maxtext/inference/paged_attention.py` (`PagedAttentionOp`) to allow all model layers to compile successfully with zero code mutations. # Tests CI integration tests # Checklist Before submitting this PR, please make sure (put X in square brackets): - [X] I have performed a self-review of my code. For an optional AI review, add the `gemini-review` label. - [X] I have necessary comments in my code, particularly in hard-to-understand areas. - [X] I have run end-to-end tests and provided workload links above if applicable. 
- [X] I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in [our documentation](https://maxtext.readthedocs.io/en/latest/development.html#adding-new-documentation-files). Copybara import of the project: -- 4c9c3b2 by Jetski <jetski@google.com>: Fix MaxEngine chunked prefill JAX dynamic shape trace alignment, serving batch capacity mismatch, and purge obsolete paged attention serving kernels Merging this change closes #3948 COPYBARA_INTEGRATE_REVIEW=#3948 from AI-Hypercomputer:fix-oom-aot-clean 4c9c3b2 PiperOrigin-RevId: 922761322
1 parent 7c7f628 commit faef777

28 files changed

Lines changed: 33 additions & 2631 deletions

src/maxtext/configs/base.yml

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,18 +1068,6 @@ context_parallel_load_balance: true
10681068
context_parallel_strategy: "all_gather" # "all_gather" or "ring"
10691069
context_parallel_reorder_strategy: "auto" # "auto", "dual_chunk_swap", or "striped"
10701070

1071-
### Paged Attention ###
1072-
# These settings take effect only when `attention=paged`.
1073-
# They should be adjusted based on the available HBM and model config.
1074-
# Note: one page group corresponds to one request/slot
1075-
pagedattn_num_pages: 64 # total number of pages to allocate
1076-
pagedattn_tokens_per_page: 32 # number of tokens each page can hold
1077-
pagedattn_pages_per_compute_block: 4 # number of pages processed together in pallas kernels
1078-
pagedattn_max_pages_per_group: -1 # defaults to number of pages needed to reach max_target_length
1079-
# Alignment of head_dim to the nearest multiple of this value, set to 0 to disable alignment. On
1080-
# TPUs, the head_dim is padded to the nearest multiple of 128.
1081-
pagedattn_head_dim_alignment: 128
1082-
10831071

10841072
# Chunked Prefill Parameters
10851073
prefill_chunk_size: 256

src/maxtext/configs/pyconfig_deprecated.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ def validate_attention_kernel(s: str) -> None:
105105
"flash",
106106
"cudnn_flash_te",
107107
"cudnn_flash_jax",
108-
"paged",
109108
"vllm_rpa",
110109
)
111110
if s not in valid_attention_kernels: # currently supported attention
@@ -119,7 +118,7 @@ def validate_attention_type(s: str) -> None:
119118

120119

121120
def validate_moba_attention(moba, attention) -> None:
122-
if moba and attention in ("autoselected", "flash", "cudnn_flash_te", "cudnn_flash_jax", "paged"):
121+
if moba and attention in ("autoselected", "flash", "cudnn_flash_te", "cudnn_flash_jax"):
123122
raise ValueError("MoBA is only supported dot_product attention")
124123

125124

@@ -816,11 +815,6 @@ def user_init(raw_keys):
816815
)
817816
raw_keys["shardy"] = False
818817

819-
if raw_keys["pagedattn_max_pages_per_group"] <= 0:
820-
raw_keys["pagedattn_max_pages_per_group"] = (
821-
raw_keys["max_target_length"] + raw_keys["pagedattn_tokens_per_page"] - 1
822-
) // raw_keys["pagedattn_tokens_per_page"]
823-
824818
raw_keys["num_slices"] = max_utils.get_num_slices(raw_keys)
825819
raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
826820
raw_keys["context_parallel_size"] = get_context_parallel_size(raw_keys)

src/maxtext/configs/types.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -698,18 +698,6 @@ class SplashAttention(BaseModel):
698698
use_splash_scheduler: bool = Field(False, description="Use experimental splash attention scheduler.")
699699

700700

701-
class PagedAttention(BaseModel):
702-
"""Tunable parameters for Paged Attention kernels."""
703-
704-
pagedattn_num_pages: int = Field(64, description="Total number of pages to allocate for paged attention.")
705-
pagedattn_tokens_per_page: int = Field(32, description="Number of tokens each page can hold.")
706-
pagedattn_pages_per_compute_block: int = Field(4, description="Number of pages processed together in pallas kernels.")
707-
pagedattn_max_pages_per_group: int = Field(-1, description="Max pages per request; -1 defaults to max_target_length.")
708-
# Alignment of head_dim to the nearest multiple of this value, set to 0 to disable alignment. On
709-
# TPUs, the head_dim is padded to the nearest multiple of 128.
710-
pagedattn_head_dim_alignment: int = Field(128, description="Alignment of head_dim to the nearest multiple.")
711-
712-
713701
class MoEGeneral(BaseModel):
714702
"""General configuration for Mixture of Experts (MoE) layers."""
715703

@@ -2260,7 +2248,6 @@ class MaxTextConfig(
22602248
AttentionIndexer,
22612249
Llama4Attention,
22622250
SplashAttention,
2263-
PagedAttention,
22642251
# Mixture of Experts
22652252
MoEGeneral,
22662253
MoEKernels,
@@ -3208,5 +3195,4 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
32083195
f"For qwen3_custom_moe, moe_expert_input_dim ({self.moe_expert_input_dim}) "
32093196
f"must be equal to attention_output_dim ({self.attention_output_dim})"
32103197
)
3211-
32123198
return self

src/maxtext/inference/maxengine/maxengine.py

Lines changed: 11 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
from maxtext.models import models
4242
from maxtext.layers import quantizations
4343
from maxtext.inference import inference_utils
44-
from maxtext.inference.page_manager import PageManager, PageState
4544
from maxtext.multimodal import processor as mm_processor
4645
from maxtext.utils import lora_utils
4746
from maxtext.utils import max_utils
@@ -129,13 +128,6 @@ def __init__(self, config: Any, devices: Any | None = None):
129128
self.param_layouts = None
130129
self.rng = None
131130

132-
# Initialize page manager and page state
133-
self.page_manager = None
134-
self.page_state = None
135-
if self.config.attention == "paged":
136-
self.page_manager = PageManager(self.config)
137-
self.page_state = self.page_manager.get_initial_page_state()
138-
139131
def print_stats(self, label: str):
140132
max_utils.print_mem_stats(label)
141133
max_utils.print_cpu_ram_stats(label)
@@ -250,7 +242,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
250242
)
251243

252244
self.prefill_kv_cache_annotations = maxtext_utils.get_prefill_kv_cache_annotations(
253-
self.model, self.config, rng2, self._mesh, self.page_state
245+
self.model, self.config, rng2, self._mesh
254246
)
255247
self.prefill_kv_cache_shardings = jax.tree_util.tree_map(
256248
lambda x: jax.sharding.NamedSharding(self._mesh, x),
@@ -265,9 +257,7 @@ def load_params(self, *args, params=None, rng: PRNGKeyType | None = None, **kwar
265257
)
266258
self.prefill_kv_cache_shardings = self.prefill_kv_cache_shardings["decoder"]["layers_0"]
267259

268-
self.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations(
269-
self.model, self.config, rng2, self._mesh, self.page_state
270-
)
260+
self.kv_cache_annotations = maxtext_utils.get_kv_cache_annotations(self.model, self.config, rng2, self._mesh)
271261
self.kv_cache_shardings = jax.tree_util.tree_map(
272262
lambda x: jax.sharding.NamedSharding(self._mesh, x),
273263
self.kv_cache_annotations,
@@ -424,7 +414,6 @@ def _prefill_jit(
424414
sampler: Callable[[Any], Any] | None = None, # pylint: disable=unused-argument
425415
rng: PRNGKeyType | None = None,
426416
slot: int | None = None,
427-
page_state: PageState | None = None,
428417
return_prompt_logp: bool = False,
429418
algorithm: str | None = None,
430419
topk: int | None = None,
@@ -524,7 +513,6 @@ def _prefill_jit(
524513
previous_chunk=previous_chunk,
525514
true_length=true_length,
526515
slot=slot,
527-
page_state=page_state,
528516
)
529517
if return_prompt_logp:
530518
prompt_logp = inference_utils.prompt_logprobs_from_prefill(flat_logits, input_tokens, true_length)
@@ -613,12 +601,6 @@ def prefill(
613601
): # returns (new_prefix, result_tokens)
614602
"""Public API for prefill that updates page state outside JIT."""
615603
# Update page state before JIT call
616-
if self.config.attention == "paged" and self.page_manager is not None and self.page_state is not None:
617-
self.page_state = self.page_manager.update_prefill_pages( # pytype: disable=attribute-error
618-
page_state=self.page_state,
619-
page_group_id=slot,
620-
true_length=true_length,
621-
)
622604

623605
# Sample rng before JIT call
624606
if rng is None:
@@ -639,7 +621,6 @@ def prefill(
639621
audio_masks=audio_masks,
640622
sampler=sampler,
641623
true_length=true_length,
642-
page_state=self.page_state, # Pass current page state
643624
slot=slot,
644625
rng=rng,
645626
return_prompt_logp=return_prompt_logp,
@@ -955,8 +936,6 @@ def generate(
955936
"""Public API for generate that updates page state outside JIT."""
956937

957938
# Update page state before JIT call
958-
if self.page_manager is not None and self.page_state is not None:
959-
self.page_state = self.page_manager.update_decode_pages(self.page_state)
960939

961940
# Sample rng before JIT call
962941
if rng is None:
@@ -969,7 +948,6 @@ def generate(
969948
params=params,
970949
decode_state=decode_state,
971950
sampler=sampler,
972-
page_state=self.page_state,
973951
rng=rng,
974952
algorithm=algorithm,
975953
topk=topk,
@@ -989,7 +967,6 @@ def _generate_jit(
989967
*,
990968
sampler: Callable[[Any], Any] | None = None, # pylint: disable=unused-argument
991969
rng: PRNGKeyType | None = None,
992-
page_state: PageState | None = None,
993970
algorithm: str | None = None,
994971
topk: int | None = None,
995972
nucleus_topp: float | None = None,
@@ -1037,7 +1014,6 @@ def _generate_jit(
10371014
model_mode=MODEL_MODE_AUTOREGRESSIVE,
10381015
rngs={"params": new_rng},
10391016
mutable=["cache"],
1040-
page_state=page_state,
10411017
)
10421018
out_logits = jax.lax.with_sharding_constraint(out_logits, self.replicated_sharding)
10431019
new_cache = jax.lax.with_sharding_constraint(new_vars["cache"], self.kv_cache_shardings)
@@ -1213,7 +1189,6 @@ def _insert_jit(
12131189
decode_state: DecodeState,
12141190
slot: int,
12151191
request_id: uuid.UUID | None = None, # pylint: disable=unused-argument
1216-
page_state_in: PageState | None = None,
12171192
) -> DecodeState:
12181193
"""Insert a single computed prefill cache into KV cache."""
12191194
unboxed_prefix = max_utils.unbox_logicallypartioned(prefix)
@@ -1269,45 +1244,12 @@ def copy(path, partial_cache, full_cache, annotations):
12691244
else:
12701245
raise ValueError(f"We don't have a strategy for inserting {path_key}")
12711246

1272-
if self.config.attention == "paged" and self.page_state is not None:
1273-
1274-
def _copy_paged(path, prefix_cache, decode_state_cache):
1275-
path_key = path[-1].key
1276-
if path_key in ["key_pages", "value_pages"]:
1277-
page_map_for_slot = page_state_in.page_map[slot] # pytype: disable=attribute-error
1278-
num_pages_to_copy = page_state_in.num_pages_used[slot] # pytype: disable=attribute-error
1279-
1280-
def _update_pages(prefix_page_idx, state):
1281-
decode_state_pages, prefix_pages, current_page_map = state
1282-
prefix_page = jax.lax.dynamic_index_in_dim(prefix_pages, prefix_page_idx, axis=1)
1283-
dest_page_idx = current_page_map[prefix_page_idx]
1284-
decode_state_pages = jax.lax.dynamic_update_slice_in_dim(
1285-
decode_state_pages, prefix_page, dest_page_idx, axis=1
1286-
)
1287-
return decode_state_pages, prefix_pages, current_page_map
1288-
1289-
decode_state_cache, _, _ = jax.lax.fori_loop(
1290-
0,
1291-
num_pages_to_copy,
1292-
_update_pages,
1293-
(decode_state_cache, prefix_cache, page_map_for_slot),
1294-
)
1295-
return decode_state_cache
1296-
else:
1297-
raise ValueError(f"We don't have a strategy for inserting {path_key} for paged attention.")
1298-
1299-
inserted_cache = jax.tree_util.tree_map_with_path(
1300-
_copy_paged,
1301-
unboxed_prefix["cache"],
1302-
decode_state["cache"],
1303-
)
1304-
else:
1305-
inserted_cache = jax.tree_util.tree_map_with_path(
1306-
copy,
1307-
unboxed_prefix["cache"],
1308-
decode_state["cache"],
1309-
self.kv_cache_annotations_named,
1310-
)
1247+
inserted_cache = jax.tree_util.tree_map_with_path(
1248+
copy,
1249+
unboxed_prefix["cache"],
1250+
decode_state["cache"],
1251+
self.kv_cache_annotations_named,
1252+
)
13111253

13121254
inserted_logits = jax.lax.dynamic_update_index_in_dim(decode_state["logits"], unboxed_prefix["logits"], slot, 0)
13131255
inserted_next_pos = jax.lax.dynamic_update_index_in_dim(
@@ -1349,23 +1291,11 @@ def insert(
13491291
) -> DecodeState:
13501292
"""Non-JIT wrapper for inserting prefill cache."""
13511293

1352-
current_page_state = None
1353-
if self.config.attention == "paged" and self.page_manager is not None:
1354-
if self.page_state is None:
1355-
self.page_state = self.page_manager.get_initial_page_state()
1356-
current_page_state = self.page_state
1357-
13581294
updated_decode_state = self._insert_jit(
13591295
prefix=prefix,
13601296
decode_state=decode_state,
13611297
slot=slot,
1362-
page_state_in=current_page_state,
13631298
)
1364-
1365-
# Update the PageState after the JIT call
1366-
if self.config.attention == "paged" and self.page_manager is not None and self.page_state is not None:
1367-
new_has_active_page = self.page_state.has_active_page.at[slot].set(True)
1368-
self.page_state = self.page_state.replace(has_active_page=new_has_active_page)
13691299
return updated_decode_state
13701300

13711301
@functools.partial(
@@ -1515,13 +1445,7 @@ def copy(path, partial_cache, full_cache, annotations):
15151445

15161446
def release_pages(self, slot: int):
15171447
"""Releases pages associated with a specific slot (page group) via the PageManager."""
1518-
if self.config.attention != "paged" or self.page_manager is None or self.page_state is None:
1519-
print(f"Warning: release_pages called for slot {slot} but paged attention is not configured or state is missing.")
1520-
return
1521-
new_page_state = self.page_manager.release_pages(
1522-
page_state=self.page_state, page_group_id=slot
1523-
) # pytype: disable=attribute-error
1524-
self.page_state = new_page_state
1448+
print(f"Warning: release_pages called for slot {slot} but paged attention is not configured.")
15251449

15261450
def get_prefix_destination_sharding(self) -> Any:
15271451
return {
@@ -1592,12 +1516,9 @@ def init_decode_state(
15921516
"""Initialises any state which a generation step transforms."""
15931517
if rng is None:
15941518
rng = jax.random.PRNGKey(0)
1595-
page_state = None
1596-
if self.config.attention == "paged" and self.page_manager is not None:
1597-
page_state = self.page_manager.get_initial_page_state() # pytype: disable=attribute-error
15981519

15991520
# pylint: disable=unused-argument
1600-
def init(abstract_params, page_state):
1521+
def init(abstract_params):
16011522
x = jnp.ones(
16021523
(int(self.config.per_device_batch_size * self.mesh.size), 1),
16031524
dtype=jnp.int32,
@@ -1622,7 +1543,6 @@ def init(abstract_params, page_state):
16221543
model_mode=MODEL_MODE_AUTOREGRESSIVE,
16231544
rngs={"params": rng},
16241545
mutable=["cache"],
1625-
page_state=page_state,
16261546
slot=0,
16271547
)
16281548

@@ -1658,7 +1578,7 @@ def init(abstract_params, page_state):
16581578
}
16591579

16601580
with nn_partitioning.axis_rules(self.config.logical_axis_rules):
1661-
abstract_outputs = jax.eval_shape(init, self.abstract_params, page_state)
1581+
abstract_outputs = jax.eval_shape(init, self.abstract_params)
16621582
logical_annotations = nn.get_partition_spec(abstract_outputs)
16631583

16641584
with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):

0 commit comments

Comments
 (0)