Skip to content

[Feat] Vllm Dumper#1507

Draft
h-guo18 wants to merge 3 commits into
mainfrom
haoguo/vllm-dumper
Draft

[Feat] Vllm Dumper#1507
h-guo18 wants to merge 3 commits into
mainfrom
haoguo/vllm-dumper

Conversation

@h-guo18

@h-guo18 h-guo18 commented May 16, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

Type of change: ?

Usage

# Add a code snippet demonstrating how to use this

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A
  • Did you get Claude approval on this PR?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features

    • Added example script for extracting hidden states from conversational data using speculative decoding with configurable runtime parameters.
  • Tests

    • Added parity test validating hidden state extraction consistency and numerical accuracy across implementations.

Review Change Stack

@copy-pr-bot

copy-pr-bot Bot commented May 16, 2026

Copy link
Copy Markdown

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented May 16, 2026

Copy link
Copy Markdown
Contributor
📝 Walkthrough

Walkthrough

This PR adds a vLLM-based hidden states extraction script and a corresponding parity test. The script loads conversations, tokenizes with loss masks, uses vLLM speculative decoding to extract hidden states, reshapes outputs, and saves results as .pt files. The test validates that vLLM and HuggingFace implementations produce equivalent outputs.

Changes

vLLM Hidden States Collection and Parity Testing

Layer / File(s) Summary
Module setup and CLI argument parsing
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
Module docstring, template preprocessing constant, and parse_args() function defining model/data paths, sequence-length filters, debug limit, and vLLM runtime flags (tensor parallelism, GPU memory, enforce-eager).
Data loading and tokenization preprocessing
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
Load conversations from jsonl, filter existing outputs, load model config and resolve auxiliary layer indices, initialize tokenizer with optional chat template override, tokenize with loss masks, and filter by sequence length bounds.
vLLM generation with speculative decoding
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
Configure vLLM runtime options (GPU memory, enforce-eager), instantiate LLM with speculative decoding and ExampleHiddenStatesConnector, set tensor parallel size, and execute batch generation.
Output processing and .pt file persistence
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
Read vLLM hidden states from safetensors, slice last-layer and reshape auxiliary-layer captures, align loss masks to token sequence length, save per-conversation .pt files with input_ids, hidden_states, aux_hidden_states, loss_mask, and conversation_id, and report success count.
Parity test utilities and validation
tests/examples/speculative_decoding/test_collect_hidden_states_parity.py
Cosine similarity helper and temporary output directory fixture; execute both vLLM and HuggingFace scripts on shared test data, assert exact equality for input_ids and loss_mask, and verify shape match with cosine similarity threshold for hidden state tensors.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Title check ❓ Inconclusive The title '[Feat] Vllm Dumper' is vague and generic. While it indicates a feature addition for vLLM, it does not clearly describe what the dumper does or its primary purpose (extracting hidden states for speculative decoding). Consider using a more specific title such as '[Feat] Add vLLM hidden states extractor for speculative decoding' to better convey the main functionality.
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed No critical security anti-patterns detected. Example code uses safe_open. Test torch.load calls exempt per SECURITY.md. trust_remote_code is configurable parameter, not hardcoded.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch haoguo/vllm-dumper

Comment @coderabbitai help to get the list of available commands and usage tips.

@codecov

codecov Bot commented May 16, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 3.84615% with 125 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.75%. Comparing base (7038dec) to head (864dc89).

Files with missing lines Patch % Lines
.../torch/speculative/plugins/hf_streaming_dataset.py 0.00% 125 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1507      +/-   ##
==========================================
- Coverage   76.93%   76.75%   -0.19%     
==========================================
  Files         474      475       +1     
  Lines       51506    51635     +129     
==========================================
+ Hits        39625    39630       +5     
- Misses      11881    12005     +124     
Flag Coverage Δ
unit 52.51% <3.84%> (-0.14%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@h-guo18 h-guo18 self-assigned this May 18, 2026
h-guo18 added 3 commits May 18, 2026 01:25
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@h-guo18 h-guo18 force-pushed the haoguo/vllm-dumper branch from 0761d1b to 4252397 Compare May 18, 2026 01:30
@github-actions

github-actions Bot commented May 18, 2026

Copy link
Copy Markdown
Contributor
PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1507/

Built to branch gh-pages at 2026-05-18 02:05 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@h-guo18 h-guo18 changed the title add vllm dumper [Feat] Streaming Dataset, Vllm Dumper May 18, 2026
@h-guo18 h-guo18 force-pushed the haoguo/vllm-dumper branch from e5c46c7 to 2bc8e4b Compare May 18, 2026 02:09
@h-guo18 h-guo18 changed the title [Feat] Streaming Dataset, Vllm Dumper [Feat] Vllm Dumper May 20, 2026
@h-guo18 h-guo18 marked this pull request as ready for review May 26, 2026 21:01
@h-guo18 h-guo18 requested a review from a team as a code owner May 26, 2026 21:01
@h-guo18 h-guo18 requested a review from kevalmorabia97 May 26, 2026 21:01
@h-guo18 h-guo18 marked this pull request as draft May 26, 2026 21:03

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py`:
- Around line 242-247: Check and guard for output.kv_transfer_params being None
before calling .get(): inside the loop where hidden_states_path =
output.kv_transfer_params.get("hidden_states_path") is used, first verify that
output.kv_transfer_params is not None (and is a dict-like object) and if it is
None, print the same warning referencing conversation_id and continue; otherwise
retrieve hidden_states_path and proceed as before.

In `@tests/examples/speculative_decoding/test_collect_hidden_states_parity.py`:
- Around line 96-97: The test unsafely deserializes artifacts using
torch.load(..., weights_only=False) for variables pt_hf and pt_vl; change those
calls to torch.load(..., weights_only=True) to avoid executing arbitrary pickled
objects (or, if full object deserialization is genuinely required, add an
explicit inline comment documenting the security rationale and why
weights_only=False is necessary). Update the two calls that set pt_hf and pt_vl
accordingly so the test only loads tensor weights unless a documented exception
is provided.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: d92dafc3-8572-4c12-aacd-7afa56e100f3

📥 Commits

Reviewing files that changed from the base of the PR and between 7038dec and 2bc8e4b.

📒 Files selected for processing (2)
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py
  • tests/examples/speculative_decoding/test_collect_hidden_states_parity.py

Comment on lines +242 to +247
hidden_states_path = output.kv_transfer_params.get("hidden_states_path")
if hidden_states_path is None:
print(
f"Warning: no hidden_states_path for conversation {conversation_id}, skipping"
)
continue

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Guard against kv_transfer_params being None.

If output.kv_transfer_params is None (e.g., when kv_transfer is not configured or an error occurs), calling .get() on it will raise an AttributeError. Consider checking for None before accessing the dictionary method.

Proposed fix
-            hidden_states_path = output.kv_transfer_params.get("hidden_states_path")
-            if hidden_states_path is None:
+            kv_params = getattr(output, "kv_transfer_params", None) or {}
+            hidden_states_path = kv_params.get("hidden_states_path")
+            if not hidden_states_path:
                 print(
                     f"Warning: no hidden_states_path for conversation {conversation_id}, skipping"
                 )
                 continue
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
hidden_states_path = output.kv_transfer_params.get("hidden_states_path")
if hidden_states_path is None:
print(
f"Warning: no hidden_states_path for conversation {conversation_id}, skipping"
)
continue
kv_params = getattr(output, "kv_transfer_params", None) or {}
hidden_states_path = kv_params.get("hidden_states_path")
if not hidden_states_path:
print(
f"Warning: no hidden_states_path for conversation {conversation_id}, skipping"
)
continue
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_vllm.py`
around lines 242 - 247, Check and guard for output.kv_transfer_params being None
before calling .get(): inside the loop where hidden_states_path =
output.kv_transfer_params.get("hidden_states_path") is used, first verify that
output.kv_transfer_params is not None (and is a dict-like object) and if it is
None, print the same warning referencing conversation_id and continue; otherwise
retrieve hidden_states_path and proceed as before.

Comment on lines +96 to +97
pt_hf = torch.load(f_hf, map_location="cpu", weights_only=False)
pt_vl = torch.load(f_vl, map_location="cpu", weights_only=False)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify there are no explicit unsafe torch.load calls in Python files.
rg -nP --type=py 'torch\.load\([^)]*weights_only\s*=\s*False' -C2

Repository: NVIDIA/Model-Optimizer

Length of output: 2186


Fix unsafe test artifact deserialization in speculative decoding parity test

Line 96-97 use torch.load(..., weights_only=False) without an inline documented exception rationale; update to weights_only=True (or add an explicit safety justification if full deserialization is required).

🔒 Suggested fix
-        pt_hf = torch.load(f_hf, map_location="cpu", weights_only=False)
-        pt_vl = torch.load(f_vl, map_location="cpu", weights_only=False)
+        pt_hf = torch.load(f_hf, map_location="cpu", weights_only=True)
+        pt_vl = torch.load(f_vl, map_location="cpu", weights_only=True)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/examples/speculative_decoding/test_collect_hidden_states_parity.py`
around lines 96 - 97, The test unsafely deserializes artifacts using
torch.load(..., weights_only=False) for variables pt_hf and pt_vl; change those
calls to torch.load(..., weights_only=True) to avoid executing arbitrary pickled
objects (or, if full object deserialization is genuinely required, add an
explicit inline comment documenting the security rationale and why
weights_only=False is necessary). Update the two calls that set pt_hf and pt_vl
accordingly so the test only loads tensor weights unless a documented exception
is provided.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant