Add Qwen3-VL Multimodal Training Support with vLLM 0.20#497
Conversation
- add multimodal preprocessing support for dataset preparation and hidden-state generation - pass prompt and multi_modal_data through offline datagen and vLLM client flows - fix text-only regressions, target-layer alignment, and empty-batch dtype handling - add a configurable 5k Qwen3-VL online training example with runtime notes - extend preprocessing, vLLM client, model utils, and training data tests Signed-off-by: Haoxiang Sun <shx2005@126.com>
Signed-off-by: Haoxiang Sun <shx2005@126.com>
Signed-off-by: Haoxiang Sun <shx2005@126.com>
Signed-off-by: Haoxiang Sun <shx2005@126.com>
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
WalkthroughThis PR introduces end-to-end multimodal training support for Qwen3-VL-4B, adding a new online training script that orchestrates data preprocessing, vLLM server setup with hidden-state extraction, and distributed training. The implementation extends data preprocessing to normalize multimodal content and load images, updates vLLM integration to handle chat completions and multimodal models, and threads hidden-states dtype throughout the training pipeline while supporting optional message-based hidden-state generation. Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant PrepScript as scripts/prepare_data.py
participant DataPrep as preprocessing.py
participant vLLMServer as vLLM Server
participant TrainScript as scripts/train.py
participant Trainer as torch.distributed training
User->>PrepScript: --multimodal flag
PrepScript->>DataPrep: load_and_preprocess_dataset(is_multimodal=True)
DataPrep->>DataPrep: load AutoProcessor + tokenizer
DataPrep->>DataPrep: normalize multimodal content
DataPrep->>DataPrep: extract & load images
DataPrep->>DataPrep: expand loss_mask for tokens
DataPrep-->>PrepScript: prepared JSONL dataset
User->>vLLMServer: launch_vllm.py (multimodal model)
vLLMServer->>vLLMServer: detect multimodal config
vLLMServer->>vLLMServer: enforce eager mode
vLLMServer->>vLLMServer: flatten text_config
vLLMServer-->>User: /health endpoint ready
User->>TrainScript: torchrun train.py
TrainScript->>DataPrep: create_collate_fn(hidden_states_dtype)
TrainScript->>vLLMServer: generate_hidden_states(messages=[...])
vLLMServer->>vLLMServer: route to chat.completions
vLLMServer->>vLLMServer: truncate prefix if needed
vLLMServer-->>TrainScript: hidden_states tensor
TrainScript->>Trainer: batched data with hidden_states
Trainer->>Trainer: forward pass with speculative decoding
Trainer-->>User: checkpoint saved
Estimated code review effort🎯 4 (Complex) | ⏱️ ~75 minutes Possibly related issues
Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
Merge ProtectionsYour pull request matches the following merge protections and will not be merged until they are valid. 🔴 Require two reviewsWaiting for
This rule is failing.PRs labelled "two-reviews" must have at least two approving reviews before merging.
|
There was a problem hiding this comment.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
scripts/data_generation_offline.py (1)
297-309:⚠️ Potential issue | 🟠 Major | ⚡ Quick winValidation logic is incompatible with multimodal prefix-truncated token outputs.
With
messagesenabled, token alignment may be prefix-truncated, but validation still enforces exacttoken_ids == input_ids. This can falsely fail valid multimodal outputs under--validate-outputs.Suggested fix
-def check_safetensors_file(path: Path, tokens: list[int]): +def check_safetensors_file( + path: Path, + tokens: list[int], + allow_prefix_truncation: bool = False, +): with safe_open(path, "pt") as f: t_ids = f.get_tensor("token_ids").tolist() - if t_ids != tokens: - raise ValueError( - f"Token ids in {path} don't match expected token ids {tokens}" - ) + if allow_prefix_truncation: + if len(t_ids) > len(tokens) or t_ids != tokens[-len(t_ids) :]: + raise ValueError( + f"Token ids in {path} are not a suffix of expected token ids" + ) + elif t_ids != tokens: + raise ValueError( + f"Token ids in {path} don't match expected token ids {tokens}" + ) @@ if validate_outputs: await asyncio.to_thread( - check_safetensors_file, target_hidden_states_path, input_ids + check_safetensors_file, + target_hidden_states_path, + input_ids, + messages is not None, )Also applies to: 316-319
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@scripts/data_generation_offline.py` around lines 297 - 309, Validation currently requires exact equality between token_ids and input_ids even when messages is used, which breaks for multimodal prefix-truncated outputs; update the validation in the code around generate_hidden_states_async and the subsequent output-checking (the blocks that compare token_ids and input_ids) to handle prefix-truncation when messages is present by accepting token_ids that match the tail of input_ids (i.e., if messages is not None then assert input_ids.endswith(token_ids) or input_ids[-len(token_ids):] == token_ids), otherwise keep the existing exact equality check.
🧹 Nitpick comments (1)
scripts/train.py (1)
64-73: ⚡ Quick winRemove implicit global
argsdependency fromsetup_dataloader.
setup_dataloaderstill readsargs.total_seq_leninstead of using an explicit parameter. That makes the function fragile outside the CLI entry flow.Suggested refactor
def setup_dataloader( dataset: BaseDataset, world_size: int, local_rank: int, + total_seq_len: int, hidden_size: int, hidden_states_dtype: torch.dtype, @@ pin_memory=True, collate_fn=create_collate_fn( - args.total_seq_len, + total_seq_len, hidden_size, hidden_states_dtype=hidden_states_dtype, preprocess=preprocess, ), @@ train_loader = setup_dataloader( train_dataset, world_size, local_rank, + args.total_seq_len, transformer_layer_config.hidden_size, hidden_states_dtype, @@ val_loader = setup_dataloader( val_dataset, world_size, local_rank, + args.total_seq_len, transformer_layer_config.hidden_size, hidden_states_dtype,Also applies to: 100-105, 343-362
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@scripts/train.py` around lines 64 - 73, The function setup_dataloader currently reads the global args.total_seq_len; change its signature to accept an explicit total_seq_len: int parameter and replace any use of args.total_seq_len inside setup_dataloader with that parameter. Update all callers to pass the CLI/parsed total_seq_len value into setup_dataloader. Additionally search for other functions in this file that reference args.total_seq_len (e.g., the blocks around the other occurrences) and refactor them the same way—add an explicit total_seq_len parameter to those functions and thread the value from the top-level CLI parsing into each call site to remove implicit global args dependency.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.sh`:
- Around line 197-200: The health-check loop that polls
http://localhost:${VLLM_PORT}/health should fail fast: add a startup deadline
(e.g., read MAX_WAIT_SECONDS or compute a start time and compare) and while
looping check that the vLLM process (${VLLM_PID}) is still alive (use kill -0 or
ps) and that the elapsed time is less than the deadline; if the PID has died or
the deadline is exceeded, log an error and exit non‑zero instead of looping
forever. Update the block around the until/curl loop to use these checks and a
clear nonzero exit path so the script terminates when vLLM never becomes
healthy.
- Around line 65-67: The script sets NUM_TRAIN_GPUS to a fixed default instead
of deriving it from TRAIN_GPUS; change the assignment so that when
NUM_TRAIN_GPUS is unset it is computed from the TRAIN_GPUS string by counting
the comma-separated GPU entries (i.e., split TRAIN_GPUS on commas and count
elements) and assign that count to NUM_TRAIN_GPUS; update the variables
VLLM_GPUS, TRAIN_GPUS, NUM_TRAIN_GPUS block so NUM_TRAIN_GPUS is conditional on
the computed length of TRAIN_GPUS rather than a hardcoded value (refer to the
TRAIN_GPUS and NUM_TRAIN_GPUS variables in the diff).
In `@src/speculators/data_generation/preprocessing.py`:
- Around line 250-252: In _preprocess_batch_multimodal, when images =
_extract_processor_images_from_conversation(normalized_conv) yields no images
the code currently returns input_ids, loss_mask early without applying
max_length truncation; change that path to apply the same max_length clipping
used in the assistant-mask branch: truncate input_ids to max_length and truncate
loss_mask to the same length (and update any seq_len/attention-related variables
accordingly) before returning so text-only rows produced while processor is set
never exceed the configured max_length.
In `@src/speculators/models/utils.py`:
- Line 24: The code currently always calls
get_verifier_config(verifier_name_or_path) to set num_layers (and related config
usage in the 26-41 region), which makes failures fatal even when the caller
provided explicit target_layer_ids; change the logic so you only load the
verifier config when target_layer_ids is None (i.e., if target_layer_ids is
provided, skip get_verifier_config entirely). For the branch that does need the
config, wrap get_verifier_config(verifier_name_or_path) in a try/except and
surface a clear error only for that branch (or fall back to a safe default), and
ensure references to num_layers, output_attentions, etc., are only derived from
the config when the config lookup succeeded.
---
Outside diff comments:
In `@scripts/data_generation_offline.py`:
- Around line 297-309: Validation currently requires exact equality between
token_ids and input_ids even when messages is used, which breaks for multimodal
prefix-truncated outputs; update the validation in the code around
generate_hidden_states_async and the subsequent output-checking (the blocks that
compare token_ids and input_ids) to handle prefix-truncation when messages is
present by accepting token_ids that match the tail of input_ids (i.e., if
messages is not None then assert input_ids.endswith(token_ids) or
input_ids[-len(token_ids):] == token_ids), otherwise keep the existing exact
equality check.
---
Nitpick comments:
In `@scripts/train.py`:
- Around line 64-73: The function setup_dataloader currently reads the global
args.total_seq_len; change its signature to accept an explicit total_seq_len:
int parameter and replace any use of args.total_seq_len inside setup_dataloader
with that parameter. Update all callers to pass the CLI/parsed total_seq_len
value into setup_dataloader. Additionally search for other functions in this
file that reference args.total_seq_len (e.g., the blocks around the other
occurrences) and refactor them the same way—add an explicit total_seq_len
parameter to those functions and thread the value from the top-level CLI parsing
into each call site to remove implicit global args dependency.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 83b2d03e-ade5-4d53-8192-bc5aae1f48d8
📒 Files selected for processing (15)
examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.shscripts/data_generation_offline.pyscripts/launch_vllm.pyscripts/prepare_data.pyscripts/train.pysrc/speculators/data_generation/preprocessing.pysrc/speculators/data_generation/vllm_client.pysrc/speculators/models/utils.pysrc/speculators/train/data.pytests/integration/datagen/test_preprocessing.pytests/unit/convert/test_eagle3_converter.pytests/unit/data_generation/__init__.pytests/unit/data_generation/test_vllm_client.pytests/unit/models/test_utils.pytests/unit/train/test_data.py
Signed-off-by: Haoxiang Sun <shx2005@126.com>
|
@shanjiaz I’ve finished the single image-text testing and code cleanup. When you have time, would you mind taking a look? I’d be very grateful for any feedback or suggestions. The commit ended up being a bit larger than I initially expected, mainly because vLLM 0.20 changed the multimodal message format. To address that, I added some compatibility handling and extra validation to make sure the multimodal path works correctly without affecting the existing text-only path. Thank you very much for your patience, guidance, and continued support. I really appreciate all your help throughout this process. Please feel free to let me know if anything looks unclear or if there is anything you would like me to adjust. |
Thank you so much for taking this up! Will take a look. : ) |
Purpose
Add Qwen3-VL Multimodal Training Support with vLLM 0.20
Description
This PR adds
prepare_data.py --multimodalfor single-image image-text preprocessing. In multimodal mode, preprocessing uses the model processor to render and tokenize the chat prompt, expands image placeholders with the actual processor image inputs so token lengths match vLLM, keeps the existing training fields (input_ids,loss_mask,seq_len), and preserves normalizedmessagesfor vLLM chat-based hidden-state generation.The offline datagen path and online on-missing generation path now carry those
messagesalongside the preprocessedinput_ids. Whenmessagesare present, the vLLM client converts processor-style image parts into OpenAI-compatible chat content and sends the request through the chat completions API withreturn_token_idsandadd_generation_prompt=false. Text-only data remains on the existing token-id completions path.For vLLM 0.20 chat requests, the returned prompt token IDs can correspond to the fully rendered chat prompt. The multimodal path therefore accepts the preprocessed
input_idsas a prefix match and trims the saved hidden states back to the preprocessed sequence length.This PR also includes follow-up fixes found during validation:
--on-missing skipThis commit ended up larger than expected, and progress was slower than planned, because vLLM 0.20 changed the multimodal message format and surfaced prompt-token inconsistencies during validation. Resolving those mismatches required additional compatibility handling, regression coverage, and a runnable Qwen3-VL 5k online training example. I also validated that the same flow scales from the 5k example dataset to the 48k LLaVA-CoT variant.
Related Issue
#290
Tests
Passed:
bash examples/train/eagle3_qwen3_vl_4b_llava_cot_5k_online.shbash examples/train/eagle3_qwen3_8b_sharegpt_online_5k.sh(regression test for the existing text-only path)make quality5k Qwen3-VL online example:
Timing from an observed run on 4x NVIDIA GeForce RTX 5090 32GB GPUs(vLLM on GPUs 0,1 and training on GPUs 2,3):Data preprocessing: 460 seconds (7 mins 40 secs)vLLM server startup: 45 secondsTraining (5 epochs): 1110 seconds (18 mins 30 secs)Total (prepare_data start to checkpoint save): 1615 seconds (26 mins 55 secs)48k Qwen3-VL online run (only 1 epoch):
val/loss_epoch=6.894val/full_acc_0_epoch=65.5%val/full_acc_1_epoch=40.6%val/full_acc_2_epoch=25.3%The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
The test plan/results, such as providing test command and pasting the results.
(Optional) The necessary documentation update.
I (a human) have written or reviewed the code in this pr to the best of my ability.