
Enable KV Cache Support for Scanned Decoder Layers and Improve vLLM Integration#3372

Open
khatwanimohit wants to merge 1 commit into main from mohit/scan_vllm_serve

Conversation

@khatwanimohit commented Mar 11, 2026

Description

This PR enables passing KV caches through scanned decoder layers, which is essential for efficient decoding when layer scanning is enabled (a common optimization for large models). It also refines the vLLM adapter and attention layers to better handle sharding and JIT tracing.
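To illustrate the core idea (this is a hypothetical sketch, not the PR's actual implementation): with `scan_layers=True`, per-layer weights are stacked on a leading layer axis and the decoder body runs under `jax.lax.scan`, so the KV cache must be threaded through the scan body alongside the activations rather than held as separate per-layer state.

```python
# Hypothetical sketch of carrying a KV cache through scanned decoder layers.
# All shapes and the toy "attention" are illustrative only.
import jax
import jax.numpy as jnp

num_layers, seq, heads, head_dim = 4, 8, 2, 16

# Stacked per-layer projection weights (leading axis = layer index).
weights = jnp.ones((num_layers, head_dim, head_dim))
# Per-layer KV cache, also stacked on the layer axis (2 = key/value slots).
kv_cache = jnp.zeros((num_layers, 2, seq, heads, head_dim))

def layer_fn(carry, per_layer):
  x = carry
  w, cache = per_layer
  # Toy layer: project the activations and write them into the key slot.
  x = jnp.einsum("shd,de->she", x, w)
  new_cache = cache.at[0].set(x)
  # Activations are the carry; the updated cache is stacked back per layer.
  return x, new_cache

x0 = jnp.ones((seq, heads, head_dim))
y, new_kv_cache = jax.lax.scan(layer_fn, x0, (weights, kv_cache))
print(y.shape, new_kv_cache.shape)  # (8, 2, 16) (4, 2, 8, 2, 16)
```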


Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Tested vllm_decode with Llama3.1-8B

NEW_MODEL_DESIGN=1 HF_TOKEN=your_hf_token_here python src/maxtext/inference/vllm_decode.py src/maxtext/configs/inference/vllm.yml \
    model_name=llama3.1-8b \
    tokenizer_path=meta-llama/Llama-3.1-8B \
    ici_tensor_parallelism=1 \
    ici_expert_parallelism=1 \
    enable_dp_attention=false \
    hbm_utilization_vllm=0.3 \
    vllm_hf_overrides='{"architectures": ["MaxTextForCausalLM"]}' \
    prompt="Suggest some famous landmarks in London." \
    load_parameters_path=gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items \
    scan_layers=true

Output:

I0420 17:34:25.625178 134458074281792 vllm_decode.py:163] Prompt: Suggest some famous landmarks in London., Generated text:  Recall the information that each landmark's design was decided in.
According to popular theories, what gender Victorian people considered fashionable?
According to the historian Walter Scott, how was Queen Victoria dressed at her Coronation?
....

Tested Qwen3-30B using vllm serve on a v5p-8

VLLM_ENABLE_V1_MULTIPROCESSING=0 NEW_MODEL_DESIGN=1 TPU_BACKEND_TYPE=jax vllm serve Qwen/Qwen3-30B-A3B \
  --seed 42 \
  --max-model-len 2048 \
  --no-enable-prefix-caching \
  --disable-log-requests \
  --max-num-batched-tokens 8192 \
  --async-scheduling \
  --tensor_parallel_size 1 \
  --gpu-memory-utilization 0.3 \
  --max-num-seqs 256 \
  --hf_overrides '{"architectures": ["MaxTextForCausalLM"]}' \
  --additional-config='{"sharding":{"sharding_strategy": {"enable_dp_attention": true, "expert_parallelism": 4}}, "maxtext_config": {"model_name": "qwen3-30b-a3b", "scan_layers": true, "log_config":false, "load_parameters_path": "gs://maxtext-model-checkpoints/qwen3-30b/2025-11-11/pathways/scanned/0/items"}}'

Received a response

{
"id": "cmpl-9f0419af227ee518",
"object": "text_completion",
"created": 1773193183,
"model": "Qwen/Qwen3-30B-A3B",
"choices": [
  {
    "index": 0,
    "text": " city in the U.S. state of Washington, and the county seat of King",
    "logprobs": null,
    "finish_reason": "length",
    "stop_reason": null,
    "token_ids": null,
    "prompt_logprobs": null,
    "prompt_token_ids": null
  }
],
"usage": {
  "prompt_tokens": 3,
  "total_tokens": 19,
  "completion_tokens": 16,
  "prompt_tokens_details": null
},
"kv_transfer_params": null
}
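For context, a response of this shape comes from vLLM's OpenAI-compatible `/v1/completions` endpoint. The request below is illustrative only: the actual prompt was not shown above (the usage block suggests a 3-token prompt and a 16-token completion), so the prompt text and `max_tokens` here are assumptions.

```python
# Hedged example: building a request body for vLLM's OpenAI-compatible
# /v1/completions endpoint. Prompt and max_tokens are illustrative guesses.
import json

payload = {
    "model": "Qwen/Qwen3-30B-A3B",
    "prompt": "Seattle is a",   # hypothetical; the real prompt was not shown
    "max_tokens": 16,
}
body = json.dumps(payload)
print(body)

# To send it against a running server:
#   curl http://localhost:8000/v1/completions \
#     -H "Content-Type: application/json" -d "$BODY"
```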

Error message shown when loading a scanned checkpoint with an unscanned model, or vice versa:

[rank0]: ValueError: Checkpoint loading failed: Checkpoint structure mismatch: 290 of 293 model parameter paths were not found in the checkpoint. This usually means a scanned (scan_layers=True) checkpoint is being loaded with scan_layers=False, or vice versa. Please ensure the checkpoint format matches the scan_layers setting.
[rank0]: Example missing paths:
[rank0]:   decoder.layers_0.mlp.wi_0.kernel
[rank0]:   decoder.layers_0.mlp.wi_1.kernel
[rank0]:   decoder.layers_0.mlp.wo.kernel
[rank0]:   decoder.layers_0.post_self_attention_layer_norm.scale
[rank0]:   decoder.layers_0.pre_self_attention_layer_norm.scale
[rank0]:   ... and 285 more
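The check behind this error can be sketched as a simple set comparison of expected versus found parameter paths (hypothetical helper, not the PR's actual code; `check_checkpoint_paths` and its signature are assumptions):

```python
# Hypothetical sketch of the mismatch check: compare the model's expected
# parameter paths against the paths present in the checkpoint and raise a
# descriptive error listing a few missing examples.
def check_checkpoint_paths(model_paths, ckpt_paths, max_examples=5):
  missing = sorted(set(model_paths) - set(ckpt_paths))
  if missing:
    examples = "\n".join(f"  {p}" for p in missing[:max_examples])
    raise ValueError(
        f"Checkpoint structure mismatch: {len(missing)} of "
        f"{len(model_paths)} model parameter paths were not found in the "
        "checkpoint. This usually means a scanned (scan_layers=True) "
        "checkpoint is being loaded with scan_layers=False, or vice versa.\n"
        f"Example missing paths:\n{examples}"
    )

# An unscanned model expects per-layer paths (layers_0, layers_1, ...),
# while a scanned checkpoint stores one stacked entry per parameter.
model = ["decoder.layers_0.mlp.wi_0.kernel", "decoder.layers_0.mlp.wo.kernel"]
ckpt = ["decoder.layers.mlp.wi_0.kernel", "decoder.layers.mlp.wo.kernel"]
try:
  check_checkpoint_paths(model, ckpt)
except ValueError as e:
  print("mismatch" in str(e))  # True
```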

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@NicoGrande

Thank you for doing this @khatwanimohit !!

CC @ChingTsai and @chishuen who were interested in this feature.

@khatwanimohit khatwanimohit force-pushed the mohit/scan_vllm_serve branch 3 times, most recently from fee2744 to ca03ea7 Compare March 11, 2026 22:43
@khatwanimohit khatwanimohit force-pushed the mohit/scan_vllm_serve branch 2 times, most recently from 794d24d to 0b434cf Compare March 12, 2026 01:33

@NicoGrande left a comment:


Can we extend support for vllm_decode.py?

@khatwanimohit khatwanimohit force-pushed the mohit/scan_vllm_serve branch from 6f348d7 to 1b53c79 Compare April 20, 2026 18:02
@NicoGrande left a comment:


LGTM! Awesome work!
