Skip to content

Commit 98e32ac

Browse files
committed
add vllm dumper
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 7038dec commit 98e32ac

1 file changed

Lines changed: 273 additions & 0 deletions

File tree

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Extract hidden states from an LLM using vLLM's ExampleHiddenStatesConnector."""
17+
18+
import argparse
19+
import tempfile
20+
from pathlib import Path
21+
22+
import torch
23+
from common import (
24+
add_answer_only_loss_args,
25+
add_aux_layers_args,
26+
load_chat_template,
27+
resolve_aux_layers,
28+
tokenize_with_loss_mask,
29+
verify_generation_tags,
30+
)
31+
from datasets import load_dataset
32+
from safetensors import safe_open
33+
from tqdm import tqdm
34+
from transformers import AutoConfig, AutoTokenizer
35+
from vllm import LLM, SamplingParams
36+
37+
REMOVE_THINK_CHAT_TEMPLATE = (
38+
"{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
39+
)
40+
41+
42+
def parse_args() -> argparse.Namespace:
43+
parser = argparse.ArgumentParser(
44+
description="Collect hidden states from conversations using vLLM."
45+
)
46+
parser.add_argument("--model", type=str, required=True, help="Name or path of the model.")
47+
parser.add_argument(
48+
"--input-data",
49+
type=Path,
50+
required=True,
51+
help="Path to a .jsonl file or directory containing .jsonl files.",
52+
)
53+
parser.add_argument(
54+
"--output-dir",
55+
type=Path,
56+
required=True,
57+
help="Directory to save hidden states as .pt files.",
58+
)
59+
parser.add_argument(
60+
"--tp",
61+
type=int,
62+
default=1,
63+
help="Tensor parallel size. Defaults to 1.",
64+
)
65+
parser.add_argument(
66+
"--max-seq-len",
67+
type=int,
68+
default=3072,
69+
help="Maximum number of tokens per conversation. Longer ones are skipped. Defaults to 3072.",
70+
)
71+
parser.add_argument(
72+
"--min-seq-len",
73+
type=int,
74+
default=10,
75+
help="Minimum number of tokens per conversation. Shorter ones are skipped. Defaults to 10.",
76+
)
77+
parser.add_argument(
78+
"--debug-max-num-conversations",
79+
type=int,
80+
default=None,
81+
help="For debugging: limit total conversations processed.",
82+
)
83+
parser.add_argument(
84+
"--trust_remote_code",
85+
action="store_true",
86+
help="Set trust_remote_code for Huggingface models and tokenizers",
87+
)
88+
add_aux_layers_args(parser)
89+
add_answer_only_loss_args(parser)
90+
return parser.parse_args()
91+
92+
93+
def main(args: argparse.Namespace) -> None:
94+
# Load conversations
95+
if args.input_data.is_file() and str(args.input_data).endswith(".jsonl"):
96+
dataset = load_dataset("json", data_files=str(args.input_data), split="train")
97+
elif args.input_data.is_dir():
98+
dataset = load_dataset(
99+
"json", data_files={"train": f"{args.input_data}/*.jsonl"}, split="train"
100+
)
101+
else:
102+
raise ValueError(
103+
f"input_data must be a .jsonl file or directory of .jsonl files, got: {args.input_data}"
104+
)
105+
print(f"Loaded {len(dataset)} conversations from {args.input_data}")
106+
107+
args.output_dir.mkdir(parents=True, exist_ok=True)
108+
109+
# Skip already processed conversations
110+
def keep_conversation(entry):
111+
conversation_id = entry.get("conversation_id", entry.get("uuid", None))
112+
assert conversation_id is not None, "Each entry must have a conversation_id or uuid field"
113+
return not (args.output_dir / f"{conversation_id}.pt").exists()
114+
115+
original_num = len(dataset)
116+
dataset = dataset.filter(keep_conversation)
117+
print(f"Removed {original_num - len(dataset)} already-processed conversations")
118+
119+
if args.debug_max_num_conversations is not None:
120+
dataset = dataset.select(range(args.debug_max_num_conversations))
121+
122+
# Determine aux layer indices per --aux-layers flag.
123+
# Convention bridge: resolve_aux_layers returns 0-based transformer layer
124+
# IDs (HF: outputs.hidden_states[lid + 1] = output of layer lid). vLLM's
125+
# `aux_hidden_state_layers` is checked against `idx + 1` after layer idx,
126+
# so the index there is also "lid + 1" — i.e. shift HF's lid by +1.
127+
# Last-layer capture: HF puts the post-final-norm result at hidden_states[N];
128+
# vLLM exposes the same position (idx+1 == N after layer N-1) but stores the
129+
# *pre-norm* residual stream there, which is fine for our consumer below.
130+
hf_config = AutoConfig.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
131+
num_hidden_layers = hf_config.num_hidden_layers
132+
aux_layer_ids = resolve_aux_layers(args, num_hidden_layers)
133+
aux_capture_ids_vllm = [lid + 1 for lid in aux_layer_ids]
134+
# All layers to capture: shifted aux layers + final-layer position N
135+
all_capture_ids = sorted({*aux_capture_ids_vllm, num_hidden_layers})
136+
print(
137+
f"Model has {num_hidden_layers} hidden layers; "
138+
f"aux layer ids (HF 0-based): {aux_layer_ids}, "
139+
f"vLLM capture ids: {all_capture_ids}"
140+
)
141+
142+
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
143+
if tokenizer.pad_token is None:
144+
tokenizer.pad_token = tokenizer.eos_token
145+
override_template = load_chat_template(args.chat_template)
146+
if override_template is not None:
147+
tokenizer.chat_template = override_template
148+
if tokenizer.chat_template:
149+
tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
150+
if args.answer_only_loss:
151+
verify_generation_tags(tokenizer.chat_template)
152+
153+
# Tokenize and filter conversations
154+
token_id_list = []
155+
conversation_ids = []
156+
loss_masks_by_id: dict[str, torch.Tensor] = {}
157+
num_skipped_too_long = 0
158+
num_invalid = 0
159+
160+
for entry in dataset:
161+
conversation_id = entry.get("conversation_id", entry.get("uuid"))
162+
conversations = entry.get("messages") or entry.get("conversations")
163+
if not conversations or not isinstance(conversations, list):
164+
num_invalid += 1
165+
continue
166+
167+
# Single apply_chat_template call produces both input_ids and loss_mask,
168+
# guaranteeing they come from the same tokenization.
169+
input_ids, loss_mask = tokenize_with_loss_mask(
170+
tokenizer, conversations, args.answer_only_loss
171+
)
172+
num_tokens = input_ids.shape[-1]
173+
if num_tokens < args.min_seq_len or num_tokens > args.max_seq_len:
174+
num_skipped_too_long += 1
175+
continue
176+
177+
token_id_list.append(input_ids.squeeze(0).tolist())
178+
conversation_ids.append(conversation_id)
179+
loss_masks_by_id[conversation_id] = loss_mask
180+
181+
print(
182+
f"Tokenized {len(token_id_list)} conversations "
183+
f"(skipped {num_skipped_too_long} by length, {num_invalid} invalid)"
184+
)
185+
186+
if not token_id_list:
187+
print("No conversations to process.")
188+
return
189+
190+
with tempfile.TemporaryDirectory() as tmpdir:
191+
llm = LLM(
192+
model=args.model,
193+
speculative_config={
194+
"method": "extract_hidden_states",
195+
"num_speculative_tokens": 1,
196+
"draft_model_config": {
197+
"hf_config": {
198+
"eagle_aux_hidden_state_layer_ids": all_capture_ids,
199+
}
200+
},
201+
},
202+
kv_transfer_config={
203+
"kv_connector": "ExampleHiddenStatesConnector",
204+
"kv_role": "kv_producer",
205+
"kv_connector_extra_config": {
206+
"shared_storage_path": tmpdir,
207+
},
208+
},
209+
tensor_parallel_size=args.tp,
210+
trust_remote_code=args.trust_remote_code,
211+
)
212+
213+
sampling_params = SamplingParams(max_tokens=1)
214+
outputs = llm.generate(
215+
[{"prompt_token_ids": ids} for ids in token_id_list],
216+
sampling_params,
217+
)
218+
219+
num_success = 0
220+
for output, conversation_id in tqdm(
221+
zip(outputs, conversation_ids),
222+
total=len(outputs),
223+
desc="Saving hidden states",
224+
):
225+
hidden_states_path = output.kv_transfer_params.get("hidden_states_path")
226+
if hidden_states_path is None:
227+
print(
228+
f"Warning: no hidden_states_path for conversation {conversation_id}, skipping"
229+
)
230+
continue
231+
232+
with safe_open(hidden_states_path, framework="pt") as f:
233+
token_ids_tensor = f.get_tensor("token_ids")
234+
# Shape from vLLM: [seq_len, num_captured_layers, hidden_dim]
235+
hidden_states_tensor = f.get_tensor("hidden_states")
236+
237+
# Last captured layer (= last model layer N-1) -> output hidden states
238+
# Earlier captured layers -> aux hidden states, concatenated along hidden dim
239+
output_hidden_states = hidden_states_tensor[:, -1, :] # [seq_len, hidden_dim]
240+
aux_hidden_states = hidden_states_tensor[:, :-1, :].reshape(
241+
hidden_states_tensor.shape[0], -1
242+
) # [seq_len, hidden_dim * num_aux_layers]
243+
244+
# Align loss_mask with the token length returned by vLLM: if vLLM
245+
# truncated, truncate; if it somehow grew (shouldn't happen), pad with 1s
246+
# so that tail positions remain trainable under non-answer-only runs.
247+
vllm_seq_len = token_ids_tensor.shape[0]
248+
loss_mask = loss_masks_by_id[conversation_id]
249+
if loss_mask.shape[0] > vllm_seq_len:
250+
loss_mask = loss_mask[:vllm_seq_len]
251+
elif loss_mask.shape[0] < vllm_seq_len:
252+
pad = torch.ones(vllm_seq_len - loss_mask.shape[0], dtype=loss_mask.dtype)
253+
loss_mask = torch.cat([loss_mask, pad], dim=0)
254+
255+
output_file = args.output_dir / f"{conversation_id}.pt"
256+
torch.save(
257+
{
258+
"input_ids": token_ids_tensor.to(torch.int64),
259+
"hidden_states": output_hidden_states,
260+
"aux_hidden_states": aux_hidden_states,
261+
"loss_mask": loss_mask,
262+
"conversation_id": conversation_id,
263+
},
264+
output_file,
265+
)
266+
num_success += 1
267+
268+
print(f"Successfully saved {num_success} / {len(token_id_list)} conversations.")
269+
270+
271+
if __name__ == "__main__":
272+
cli_args = parse_args()
273+
main(cli_args)

0 commit comments

Comments
 (0)