Add Qwen3-VL Multimodal Training Support with vLLM 0.20 #497

Draft
shx2005 wants to merge 6 commits into vllm-project:main from shx2005:mm-rebase-qwen3vl-origin-main

@shx2005 shx2005 commented May 1, 2026

Purpose

Add Qwen3-VL Multimodal Training Support with vLLM 0.20

Description

This PR adds prepare_data.py --multimodal for single-image image-text preprocessing. In multimodal mode, preprocessing uses the model processor to render and tokenize the chat prompt, expands image placeholders with the actual processor image inputs so token lengths match vLLM, keeps the existing training fields (input_ids, loss_mask, seq_len), and preserves normalized messages for vLLM chat-based hidden-state generation.

The offline datagen path and online on-missing generation path now carry those messages alongside the preprocessed input_ids. When messages are present, the vLLM client converts processor-style image parts into OpenAI-compatible chat content and sends the request through the chat completions API with return_token_ids and add_generation_prompt=false. Text-only data remains on the existing token-id completions path.
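
As a rough illustration of the conversion described above, the sketch below turns processor-style image parts into OpenAI-style `image_url` parts and builds a chat request body. The function names, the `"image"` key on processor-style parts, the `file://` handling for local paths, and `max_tokens=1` are assumptions made for this example and are not taken from the PR; only `return_token_ids` and `add_generation_prompt` come from the description. Whether a server accepts `file://` URLs depends on how it is configured.

```python
from pathlib import Path
from typing import Any

# Prefixes that are already usable as an image URL and need no rewriting.
_URL_PREFIXES = ("http://", "https://", "data:", "file://")


def to_openai_content(part: dict[str, Any]) -> dict[str, Any]:
    """Convert one processor-style content part to an OpenAI-style part."""
    if part.get("type") == "image":
        image = str(part["image"])  # assumed key for processor-style parts
        if not image.startswith(_URL_PREFIXES):
            # Assumption: local paths are sent as absolute file:// URLs.
            image = Path(image).resolve().as_uri()
        return {"type": "image_url", "image_url": {"url": image}}
    # Text parts (and anything unrecognized) pass through unchanged.
    return part


def to_openai_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert every message, leaving plain-string content untouched."""
    converted = []
    for message in messages:
        content = message["content"]
        if isinstance(content, list):
            content = [to_openai_content(p) for p in content]
        converted.append({"role": message["role"], "content": content})
    return converted


def build_chat_request(model: str, messages: list[dict[str, Any]]) -> dict[str, Any]:
    """Build a chat-completions body that asks for prompt token ids."""
    return {
        "model": model,
        "messages": to_openai_messages(messages),
        "max_tokens": 1,  # assumption: only the prompt pass matters here
        "return_token_ids": True,
        "add_generation_prompt": False,
    }
```

Text-only rows would not go through this path at all, which matches the description's statement that they stay on the token-id completions route.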

For vLLM 0.20 chat requests, the returned prompt token IDs can correspond to the fully rendered chat prompt. The multimodal path therefore accepts the preprocessed input_ids as a prefix match and trims the saved hidden states back to the preprocessed sequence length.
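
The prefix match and trim described above might look roughly like the following. This is a minimal sketch under stated assumptions: `trim_to_preprocessed` is a hypothetical helper name, and hidden states are modeled as any sliceable sequence with one row per prompt token (a tensor would slice the same way along its first dimension).

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def trim_to_preprocessed(
    input_ids: list[int],
    returned_ids: list[int],
    hidden_states: Sequence[T],
) -> tuple[list[int], Sequence[T]]:
    """Accept input_ids as a prefix of returned_ids and trim to its length."""
    n = len(input_ids)
    # The server may return more tokens than were preprocessed, but the
    # preprocessed ids must match the start of what it returned.
    if len(returned_ids) < n or returned_ids[:n] != input_ids:
        raise ValueError("preprocessed input_ids are not a prefix of the returned ids")
    # One hidden-state row is expected per returned prompt token.
    if len(hidden_states) != len(returned_ids):
        raise ValueError("hidden_states length does not match returned ids")
    # Keep only the rows that correspond to the preprocessed sequence.
    return returned_ids[:n], hidden_states[:n]
```

For example, with `input_ids=[1, 2, 3]` and `returned_ids=[1, 2, 3, 4]`, the helper keeps the first three ids and the first three hidden-state rows; a mismatch anywhere in the first three positions raises instead of silently saving misaligned data.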

This PR also includes follow-up fixes found during validation:

  • keep text-only preprocessing free of multimodal columns
  • strip the verifier final layer from explicit target-layer ids before saving draft auxiliary layer ids
  • preserve the configured hidden-state dtype for empty batches created by --on-missing skip
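
The second fix in the list above (dropping the verifier's final layer from explicit target-layer ids) can be sketched as follows. The helper name and the explicit `final_layer_id` parameter are hypothetical; the PR's actual `resolve_target_layer_ids` derives the layer count from the verifier config, and its indexing convention is not shown in this conversation.

```python
import warnings


def strip_final_layer(target_layer_ids: list[int], final_layer_id: int) -> list[int]:
    """Drop the verifier's final layer from the draft auxiliary layer ids."""
    kept = [layer for layer in target_layer_ids if layer != final_layer_id]
    if len(kept) != len(target_layer_ids):
        # The final layer is consumed separately (as the verifier's last
        # hidden states), so it should not also be saved as an auxiliary layer.
        warnings.warn(
            f"Removed final layer {final_layer_id} from target layer ids",
            stacklevel=2,
        )
    return kept
```

Passing the final layer id explicitly keeps the sketch independent of whether layers are counted from 0 or 1.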

This commit ended up larger than expected, and progress was slower than planned, because vLLM 0.20 changed the multimodal message format and surfaced prompt-token inconsistencies during validation. Resolving those mismatches required additional compatibility handling, regression coverage, and a runnable Qwen3-VL 5k online training example. I also validated that the same flow scales from the 5k example dataset to the 48k LLaVA-CoT variant.

Related Issue

#290

Tests

Passed:

  • bash examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh

  • bash examples/train/eagle3_qwen3_8b_sharegpt_online_5k.sh (regression test for the existing text-only path)

  • make quality

  • 5k Qwen3-VL online example:

    Timing from an observed run on 4x NVIDIA GeForce RTX 5090 32GB GPUs
    (vLLM on GPUs 0,1 and training on GPUs 2,3):
    Data preprocessing: 460 seconds (7 mins 40 secs)
    vLLM server startup: 45 seconds
    Training (5 epochs): 1110 seconds (18 mins 30 secs)
    Total (prepare_data start to checkpoint save): 1615 seconds (26 mins 55 secs)

  • 48k Qwen3-VL online run (only 1 epoch):

    • val/loss_epoch=6.894
    • val/full_acc_0_epoch=65.5%
    • val/full_acc_1_epoch=40.6%
    • val/full_acc_2_epoch=25.3%
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".

  • The test plan/results, such as providing test command and pasting the results.

  • (Optional) The necessary documentation update.

  • I (a human) have written or reviewed the code in this pr to the best of my ability.

shx2005 added 4 commits April 29, 2026 11:28
- add multimodal preprocessing support for dataset preparation and hidden-state generation
- pass prompt and multi_modal_data through offline datagen and vLLM client flows
- fix text-only regressions, target-layer alignment, and empty-batch dtype handling
- add a configurable 5k Qwen3-VL online training example with runtime notes
- extend preprocessing, vLLM client, model utils, and training data tests

Signed-off-by: Haoxiang Sun <shx2005@126.com>
Signed-off-by: Haoxiang Sun <shx2005@126.com>
Signed-off-by: Haoxiang Sun <shx2005@126.com>
Signed-off-by: Haoxiang Sun <shx2005@126.com>
@coderabbitai
Contributor

coderabbitai Bot commented May 1, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: a7385dea-b6e2-4af1-be22-714fa254ea4b

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

This PR introduces end-to-end multimodal training support for Qwen3-VL-4B, adding a new online training script that orchestrates data preprocessing, vLLM server setup with hidden-state extraction, and distributed training. The implementation extends data preprocessing to normalize multimodal content and load images, updates vLLM integration to handle chat completions and multimodal models, and threads hidden-states dtype throughout the training pipeline while supporting optional message-based hidden-state generation.

Changes

| Cohort | File(s) | Summary |
|---|---|---|
| New Example Training Script | `examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh` | Adds executable training script orchestrating online multimodal workflow: downloads dataset shards, normalizes image references into absolute paths, invokes preprocessing with multimodal settings, starts vLLM server with hidden-state output, polls for readiness, and launches distributed training with generate/cache behavior for missing hidden states. |
| Data Preprocessing & Normalization | `src/speculators/data_generation/preprocessing.py` | Introduces multimodal preprocessing path with content normalization, image extraction/loading, loss-mask expansion for multimodal tokens, and support for both conversations and messages schema keys. Adds processor-based batch preprocessing and dataset building routes with reduced batch sizes for multimodal inputs. |
| vLLM Hidden-State Generation | `src/speculators/data_generation/vllm_client.py` | Adds multimodal chat-message preparation utilities translating processor-style messages into OpenAI/vLLM chat format with image URL conversion. Enhances output extraction with defensive field access, token-id normalization, and safetensors-based prefix trimming. Extends hidden-state generation functions to accept optional messages parameter for chat-completion routing. |
| vLLM Server Launch | `scripts/launch_vllm.py` | Detects multimodal model configurations and enforces eager mode with --enforce-eager flag. Refactors layer-count computation to use text_config when available. Expands HF config passed to vLLM by conditionally flattening multimodal/nested text configuration before embedding under speculative config. |
| Data Generation Pipeline | `scripts/data_generation_offline.py` | Adds queue-item construction utilities normalizing input_ids to list[int] and optionally carrying through messages. Worker logic extracts and forwards optional messages into hidden-states generation. Adds logging for messages column presence and actual resolved output directory. |
| Training Data Infrastructure | `src/speculators/train/data.py`, `scripts/train.py` | Threads hidden_states_dtype parameter through dataloader setup and collate function. Sets explicit dtypes for hidden_states and verifier_last_hidden_states while keeping index/mask tensors as torch.long. Updates hidden-state generation to defensively handle input_ids and forward optional per-sample messages. |
| Preprocessing Configuration | `scripts/prepare_data.py` | Adds --multimodal boolean CLI argument enabling multimodal preprocessing path with message preservation for vLLM hidden-state generation. |
| Model Utilities | `src/speculators/models/utils.py` | Updates resolve_target_layer_ids to filter out verifier's final layer when explicitly provided, emitting warning when stripping occurs, clarifying that final layer is consumed separately as verifier_last_hidden_states. |
| Comprehensive Test Coverage | `tests/unit/data_generation/test_vllm_client.py`, `tests/integration/datagen/test_preprocessing.py`, `tests/unit/train/test_data.py`, `tests/unit/models/test_utils.py`, `tests/unit/convert/test_eagle3_converter.py`, `tests/unit/data_generation/__init__.py` | Adds unit tests for multimodal vLLM client with dummy sync/async implementations, image URL conversion, and prefix truncation. Integration tests validate content normalization, loss-mask expansion, dataset schema compatibility, and multimodal processor integration. Unit tests verify dtype preservation and layer-id filtering behavior. |
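
To make the "Data Generation Pipeline" summary above more concrete, here is a hedged sketch of what queue-item construction could look like: `input_ids` is normalized to a plain `list[int]` and `messages` is carried along only when the row has it. The function name, the `index` field, and the dict layout are assumptions for illustration, not the PR's actual code.

```python
from typing import Any, Mapping


def build_queue_item(index: int, row: Mapping[str, Any]) -> dict[str, Any]:
    """Build one work item for the hidden-state generation queue."""
    ids = row["input_ids"]
    # Tensors and arrays expose tolist(); plain lists and tuples do not.
    if hasattr(ids, "tolist"):
        ids = ids.tolist()
    item: dict[str, Any] = {
        "index": index,
        "input_ids": [int(token) for token in ids],
    }
    # Text-only rows have no messages column, so the key is simply omitted
    # and the worker can fall back to the token-id completions path.
    messages = row.get("messages")
    if messages is not None:
        item["messages"] = messages
    return item
```

Omitting the key for text-only rows, rather than storing `None`, lets downstream code decide the routing with a simple membership check.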

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant PrepScript as scripts/prepare_data.py
    participant DataPrep as preprocessing.py
    participant vLLMServer as vLLM Server
    participant TrainScript as scripts/train.py
    participant Trainer as torch.distributed training

    User->>PrepScript: --multimodal flag
    PrepScript->>DataPrep: load_and_preprocess_dataset(is_multimodal=True)
    DataPrep->>DataPrep: load AutoProcessor + tokenizer
    DataPrep->>DataPrep: normalize multimodal content
    DataPrep->>DataPrep: extract & load images
    DataPrep->>DataPrep: expand loss_mask for tokens
    DataPrep-->>PrepScript: prepared JSONL dataset
    
    User->>vLLMServer: launch_vllm.py (multimodal model)
    vLLMServer->>vLLMServer: detect multimodal config
    vLLMServer->>vLLMServer: enforce eager mode
    vLLMServer->>vLLMServer: flatten text_config
    vLLMServer-->>User: /health endpoint ready

    User->>TrainScript: torchrun train.py
    TrainScript->>DataPrep: create_collate_fn(hidden_states_dtype)
    TrainScript->>vLLMServer: generate_hidden_states(messages=[...])
    vLLMServer->>vLLMServer: route to chat.completions
    vLLMServer->>vLLMServer: truncate prefix if needed
    vLLMServer-->>TrainScript: hidden_states tensor
    TrainScript->>Trainer: batched data with hidden_states
    Trainer->>Trainer: forward pass with speculative decoding
    Trainer-->>User: checkpoint saved

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related issues

  • Implements multimodal training support as described in RFC issue, providing end-to-end data preprocessing, vLLM integration, and training pipeline modifications for Qwen3-VL-4B models.

Possibly related PRs

  • PR #433: Overlapping modifications to data-generation pipeline and hidden-state generation tooling in scripts and vLLM client.
  • PR #378: Related changes to hidden-states dtype handling in training script and vLLM launch behavior configuration.
  • PR #436: Both add orchestration scripts in examples/train/ coordinating prepare_data.py, launch_vllm.py, and distributed training workflows.

Suggested labels

enhancement, training, data-generation, two-reviews

Suggested reviewers

  • shanjiaz
  • dsikka
  • fynnsu
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 42.27%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and specifically describes the main change: adding Qwen3-VL multimodal training support with vLLM 0.20, which aligns with the core objectives of enabling multimodal preprocessing and training flows. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Description check | ✅ Passed | The PR description clearly describes the purpose (adding Qwen3-VL multimodal training support), the implementation approach (multimodal preprocessing, vLLM integration, message preservation), related fixes, test results, and references an existing issue. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@mergify mergify Bot added the documentation label (Improvements or additions to documentation) May 1, 2026
@shx2005
Author

shx2005 commented May 1, 2026

@coderabbitai review

@coderabbitai
Contributor

coderabbitai Bot commented May 1, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@mergify

mergify Bot commented May 1, 2026

Merge Protections

Your pull request matches the following merge protections and will not be merged until they are valid.

🔴 Require two reviews

This rule is failing. Waiting for:

  • #approved-reviews-by >= 2

PRs labelled "two-reviews" must have at least two approving reviews before merging.

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
scripts/data_generation_offline.py (1)

297-309: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validation logic is incompatible with multimodal prefix-truncated token outputs.

With messages enabled, token alignment may be prefix-truncated, but validation still enforces exact token_ids == input_ids. This can falsely fail valid multimodal outputs under --validate-outputs.

Suggested fix
-def check_safetensors_file(path: Path, tokens: list[int]):
+def check_safetensors_file(
+    path: Path,
+    tokens: list[int],
+    allow_prefix_truncation: bool = False,
+):
     with safe_open(path, "pt") as f:
         t_ids = f.get_tensor("token_ids").tolist()
-        if t_ids != tokens:
-            raise ValueError(
-                f"Token ids in {path} don't match expected token ids {tokens}"
-            )
+        if allow_prefix_truncation:
+            if len(t_ids) > len(tokens) or t_ids != tokens[-len(t_ids) :]:
+                raise ValueError(
+                    f"Token ids in {path} are not a suffix of expected token ids"
+                )
+        elif t_ids != tokens:
+            raise ValueError(
+                f"Token ids in {path} don't match expected token ids {tokens}"
+            )
@@
                 if validate_outputs:
                     await asyncio.to_thread(
-                        check_safetensors_file, target_hidden_states_path, input_ids
+                        check_safetensors_file,
+                        target_hidden_states_path,
+                        input_ids,
+                        messages is not None,
                     )

Also applies to: 316-319

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@scripts/data_generation_offline.py` around lines 297 - 309, Validation
currently requires exact equality between token_ids and input_ids even when
messages is used, which breaks for multimodal prefix-truncated outputs; update
the validation in the code around generate_hidden_states_async and the
subsequent output-checking (the blocks that compare token_ids and input_ids) to
handle prefix-truncation when messages is present by accepting token_ids that
match the tail of input_ids (i.e., if messages is not None then assert
len(token_ids) <= len(input_ids) and input_ids[-len(token_ids):] == token_ids),
otherwise keep the existing exact equality check.
🧹 Nitpick comments (1)
scripts/train.py (1)

64-73: ⚡ Quick win

Remove implicit global args dependency from setup_dataloader.

setup_dataloader still reads args.total_seq_len instead of using an explicit parameter. That makes the function fragile outside the CLI entry flow.

Suggested refactor
 def setup_dataloader(
     dataset: BaseDataset,
     world_size: int,
     local_rank: int,
+    total_seq_len: int,
     hidden_size: int,
     hidden_states_dtype: torch.dtype,
@@
         pin_memory=True,
         collate_fn=create_collate_fn(
-            args.total_seq_len,
+            total_seq_len,
             hidden_size,
             hidden_states_dtype=hidden_states_dtype,
             preprocess=preprocess,
         ),
@@
     train_loader = setup_dataloader(
         train_dataset,
         world_size,
         local_rank,
+        args.total_seq_len,
         transformer_layer_config.hidden_size,
         hidden_states_dtype,
@@
     val_loader = setup_dataloader(
         val_dataset,
         world_size,
         local_rank,
+        args.total_seq_len,
         transformer_layer_config.hidden_size,
         hidden_states_dtype,

Also applies to: 100-105, 343-362

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@scripts/train.py` around lines 64 - 73, The function setup_dataloader
currently reads the global args.total_seq_len; change its signature to accept an
explicit total_seq_len: int parameter and replace any use of args.total_seq_len
inside setup_dataloader with that parameter. Update all callers to pass the
CLI/parsed total_seq_len value into setup_dataloader. Additionally search for
other functions in this file that reference args.total_seq_len (e.g., the blocks
around the other occurrences) and refactor them the same way—add an explicit
total_seq_len parameter to those functions and thread the value from the
top-level CLI parsing into each call site to remove implicit global args
dependency.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh`:
- Around line 197-200: The health-check loop that polls
http://localhost:${VLLM_PORT}/health should fail fast: add a startup deadline
(e.g., read MAX_WAIT_SECONDS or compute a start time and compare) and while
looping check that the vLLM process (${VLLM_PID}) is still alive (use kill -0 or
ps) and that the elapsed time is less than the deadline; if the PID has died or
the deadline is exceeded, log an error and exit non‑zero instead of looping
forever. Update the block around the until/curl loop to use these checks and a
clear nonzero exit path so the script terminates when vLLM never becomes
healthy.
- Around line 65-67: The script sets NUM_TRAIN_GPUS to a fixed default instead
of deriving it from TRAIN_GPUS; change the assignment so that when
NUM_TRAIN_GPUS is unset it is computed from the TRAIN_GPUS string by counting
the comma-separated GPU entries (i.e., split TRAIN_GPUS on commas and count
elements) and assign that count to NUM_TRAIN_GPUS; update the variables
VLLM_GPUS, TRAIN_GPUS, NUM_TRAIN_GPUS block so NUM_TRAIN_GPUS is conditional on
the computed length of TRAIN_GPUS rather than a hardcoded value (refer to the
TRAIN_GPUS and NUM_TRAIN_GPUS variables in the diff).

In `@src/speculators/data_generation/preprocessing.py`:
- Around line 250-252: In _preprocess_batch_multimodal, when images =
_extract_processor_images_from_conversation(normalized_conv) yields no images
the code currently returns input_ids, loss_mask early without applying
max_length truncation; change that path to apply the same max_length clipping
used in the assistant-mask branch: truncate input_ids to max_length and truncate
loss_mask to the same length (and update any seq_len/attention-related variables
accordingly) before returning so text-only rows produced while processor is set
never exceed the configured max_length.

In `@src/speculators/models/utils.py`:
- Line 24: The code currently always calls
get_verifier_config(verifier_name_or_path) to set num_layers (and related config
usage in the 26-41 region), which makes failures fatal even when the caller
provided explicit target_layer_ids; change the logic so you only load the
verifier config when target_layer_ids is None (i.e., if target_layer_ids is
provided, skip get_verifier_config entirely). For the branch that does need the
config, wrap get_verifier_config(verifier_name_or_path) in a try/except and
surface a clear error only for that branch (or fall back to a safe default), and
ensure references to num_layers, output_attentions, etc., are only derived from
the config when the config lookup succeeded.

---

Outside diff comments:
In `@scripts/data_generation_offline.py`:
- Around line 297-309: Validation currently requires exact equality between
token_ids and input_ids even when messages is used, which breaks for multimodal
prefix-truncated outputs; update the validation in the code around
generate_hidden_states_async and the subsequent output-checking (the blocks that
compare token_ids and input_ids) to handle prefix-truncation when messages is
present by accepting token_ids that match the tail of input_ids (i.e., if
messages is not None then assert len(token_ids) <= len(input_ids) and
input_ids[-len(token_ids):] == token_ids), otherwise keep the existing exact
equality check.

---

Nitpick comments:
In `@scripts/train.py`:
- Around line 64-73: The function setup_dataloader currently reads the global
args.total_seq_len; change its signature to accept an explicit total_seq_len:
int parameter and replace any use of args.total_seq_len inside setup_dataloader
with that parameter. Update all callers to pass the CLI/parsed total_seq_len
value into setup_dataloader. Additionally search for other functions in this
file that reference args.total_seq_len (e.g., the blocks around the other
occurrences) and refactor them the same way—add an explicit total_seq_len
parameter to those functions and thread the value from the top-level CLI parsing
into each call site to remove implicit global args dependency.
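
The two inline findings for the example script (derive the training GPU count from TRAIN_GPUS, and fail fast while waiting for vLLM) are about shell code, but the logic can be sketched in Python for clarity. Everything below is illustrative: the function names, defaults, and the injected `is_alive` / `is_healthy` callables are assumptions, not part of the PR.

```python
import time
from typing import Callable


def count_train_gpus(train_gpus: str) -> int:
    """Count comma-separated GPU ids, e.g. "2,3" gives 2."""
    return len([gpu for gpu in train_gpus.split(",") if gpu.strip()])


def wait_until_healthy(
    is_alive: Callable[[], bool],
    is_healthy: Callable[[], bool],
    max_wait_seconds: float = 600.0,
    poll_seconds: float = 2.0,
    clock: Callable[[], float] = time.monotonic,
    sleep: Callable[[float], None] = time.sleep,
) -> None:
    """Poll until healthy; raise if the process dies or the deadline passes."""
    deadline = clock() + max_wait_seconds
    while True:
        if is_healthy():
            return
        # A dead server will never become healthy, so stop immediately.
        if not is_alive():
            raise RuntimeError("vLLM process exited before becoming healthy")
        if clock() >= deadline:
            raise TimeoutError(f"vLLM not healthy after {max_wait_seconds} seconds")
        sleep(poll_seconds)
```

In a real launcher, `is_alive` could wrap `proc.poll() is None` for a `subprocess.Popen` handle and `is_healthy` could issue an HTTP GET against the server's `/health` endpoint, treating connection errors as "not yet healthy". Injecting them as callables keeps the loop easy to test without starting a server.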

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 83b2d03e-ade5-4d53-8192-bc5aae1f48d8

📥 Commits

Reviewing files that changed from the base of the PR and between 4afaf98 and 4d06511.

📒 Files selected for processing (15)
  • examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh
  • scripts/data_generation_offline.py
  • scripts/launch_vllm.py
  • scripts/prepare_data.py
  • scripts/train.py
  • src/speculators/data_generation/preprocessing.py
  • src/speculators/data_generation/vllm_client.py
  • src/speculators/models/utils.py
  • src/speculators/train/data.py
  • tests/integration/datagen/test_preprocessing.py
  • tests/unit/convert/test_eagle3_converter.py
  • tests/unit/data_generation/__init__.py
  • tests/unit/data_generation/test_vllm_client.py
  • tests/unit/models/test_utils.py
  • tests/unit/train/test_data.py

@shx2005 shx2005 mentioned this pull request May 1, 2026
4 tasks
Signed-off-by: Haoxiang Sun <shx2005@126.com>
@shx2005
Author

shx2005 commented May 1, 2026

@shanjiaz I’ve finished the single image-text testing and code cleanup. When you have time, would you mind taking a look? I’d be very grateful for any feedback or suggestions.

The commit ended up being a bit larger than I initially expected, mainly because vLLM 0.20 changed the multimodal message format. To address that, I added some compatibility handling and extra validation to make sure the multimodal path works correctly without affecting the existing text-only path.

Thank you very much for your patience, guidance, and continued support. I really appreciate all your help throughout this process. Please feel free to let me know if anything looks unclear or if there is anything you would like me to adjust.

@shx2005 shx2005 marked this pull request as ready for review May 1, 2026 12:52
@shanjiaz shanjiaz requested review from fynnsu and shanjiaz May 1, 2026 13:03
@shanjiaz
Collaborator

shanjiaz commented May 1, 2026

> @shanjiaz I’ve finished the single image-text testing and code cleanup. When you have time, would you mind taking a look? I’d be very grateful for any feedback or suggestions.
>
> The commit ended up being a bit larger than I initially expected, mainly because vLLM 0.20 changed the multimodal message format. To address that, I added some compatibility handling and extra validation to make sure the multimodal path works correctly without affecting the existing text-only path.
>
> Thank you very much for your patience, guidance, and continued support. I really appreciate all your help throughout this process. Please feel free to let me know if anything looks unclear or if there is anything you would like me to adjust.

Thank you so much for taking this up! Will take a look. : )

@shx2005 shx2005 marked this pull request as draft May 9, 2026 00:49

Labels

data-generation, documentation (Improvements or additions to documentation), enhancement (New feature or request), training, two-reviews

2 participants