|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +"""Extract hidden states from an LLM using vLLM's ExampleHiddenStatesConnector.""" |
| 17 | + |
| 18 | +import argparse |
| 19 | +import tempfile |
| 20 | +from pathlib import Path |
| 21 | + |
| 22 | +import torch |
| 23 | +from common import ( |
| 24 | + add_answer_only_loss_args, |
| 25 | + add_aux_layers_args, |
| 26 | + load_chat_template, |
| 27 | + resolve_aux_layers, |
| 28 | + tokenize_with_loss_mask, |
| 29 | + verify_generation_tags, |
| 30 | +) |
| 31 | +from datasets import load_dataset |
| 32 | +from safetensors import safe_open |
| 33 | +from tqdm import tqdm |
| 34 | +from transformers import AutoConfig, AutoTokenizer |
| 35 | +from vllm import LLM, SamplingParams |
| 36 | + |
# Jinja fragment that replaces a message's content with the text following the
# last '</think>' tag. main() deletes this exact fragment from the tokenizer's
# chat template (via str.replace) before tokenizing conversations.
REMOVE_THINK_CHAT_TEMPLATE = (
    "{% if '</think>' in content %}"
    "{% set content = content.split('</think>')[-1] %}"
    "{% endif %}"
)
| 40 | + |
| 41 | + |
def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments for hidden-state collection.

    Besides the options declared here, ``add_aux_layers_args`` and
    ``add_answer_only_loss_args`` register their own options on the parser.

    Returns:
        The parsed argument namespace.

    Raises:
        SystemExit: Raised by argparse on invalid arguments, including when
            ``--min-seq-len`` is greater than ``--max-seq-len``.
    """
    parser = argparse.ArgumentParser(
        description="Collect hidden states from conversations using vLLM."
    )
    parser.add_argument("--model", type=str, required=True, help="Name or path of the model.")
    parser.add_argument(
        "--input-data",
        type=Path,
        required=True,
        help="Path to a .jsonl file or directory containing .jsonl files.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="Directory to save hidden states as .pt files.",
    )
    parser.add_argument(
        "--tp",
        type=int,
        default=1,
        help="Tensor parallel size. Defaults to 1.",
    )
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=3072,
        help="Maximum number of tokens per conversation. Longer ones are skipped. Defaults to 3072.",
    )
    parser.add_argument(
        "--min-seq-len",
        type=int,
        default=10,
        help="Minimum number of tokens per conversation. Shorter ones are skipped. Defaults to 10.",
    )
    parser.add_argument(
        "--debug-max-num-conversations",
        type=int,
        default=None,
        help="For debugging: limit total conversations processed.",
    )
    # Every other flag is hyphenated; accept the hyphenated spelling too while
    # keeping the original underscore spelling as an alias for existing callers.
    parser.add_argument(
        "--trust-remote-code",
        "--trust_remote_code",
        dest="trust_remote_code",
        action="store_true",
        help="Set trust_remote_code for Huggingface models and tokenizers",
    )
    add_aux_layers_args(parser)
    add_answer_only_loss_args(parser)
    args = parser.parse_args()

    # An inverted length window would silently skip every conversation, so
    # reject it up front instead of loading the model for nothing.
    if args.min_seq_len > args.max_seq_len:
        parser.error(
            f"--min-seq-len ({args.min_seq_len}) must not exceed "
            f"--max-seq-len ({args.max_seq_len})"
        )
    return args
| 91 | + |
| 92 | + |
def _get_conversation_id(entry):
    """Return the entry's ``conversation_id``, falling back to ``uuid``.

    The fallback also applies when ``conversation_id`` is present but ``None``.
    Returns ``None`` when neither field holds a value.
    """
    conversation_id = entry.get("conversation_id")
    if conversation_id is None:
        conversation_id = entry.get("uuid")
    return conversation_id


def _align_loss_mask(loss_mask, target_len):
    """Truncate or pad ``loss_mask`` along dim 0 so its length is ``target_len``.

    Padding uses ones so that any extra tail positions stay trainable.
    """
    current_len = loss_mask.shape[0]
    if current_len > target_len:
        return loss_mask[:target_len]
    if current_len < target_len:
        pad = torch.ones(target_len - current_len, dtype=loss_mask.dtype)
        return torch.cat([loss_mask, pad], dim=0)
    return loss_mask


def main(args: argparse.Namespace) -> None:
    """Collect per-conversation hidden states with vLLM and save them as .pt files.

    Steps:
      1. Load conversations from a .jsonl file or a directory of .jsonl files.
      2. Drop conversations whose ``<output_dir>/<conversation_id>.pt`` exists.
      3. Tokenize each conversation and skip those outside the
         [min_seq_len, max_seq_len] token range.
      4. Run vLLM with the ExampleHiddenStatesConnector, read each request's
         safetensors dump, and save input ids, final-position hidden states,
         concatenated aux hidden states, and the loss mask.

    Args:
        args: Parsed command-line arguments (see ``parse_args``).

    Raises:
        ValueError: If ``args.input_data`` is neither a .jsonl file nor a
            directory, or if an entry has no ``conversation_id``/``uuid``.
    """
    # Load conversations
    if args.input_data.is_file() and str(args.input_data).endswith(".jsonl"):
        dataset = load_dataset("json", data_files=str(args.input_data), split="train")
    elif args.input_data.is_dir():
        dataset = load_dataset(
            "json", data_files={"train": f"{args.input_data}/*.jsonl"}, split="train"
        )
    else:
        raise ValueError(
            f"input_data must be a .jsonl file or directory of .jsonl files, got: {args.input_data}"
        )
    print(f"Loaded {len(dataset)} conversations from {args.input_data}")

    args.output_dir.mkdir(parents=True, exist_ok=True)

    # Skip already processed conversations
    def keep_conversation(entry):
        conversation_id = _get_conversation_id(entry)
        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        if conversation_id is None:
            raise ValueError("Each entry must have a conversation_id or uuid field")
        return not (args.output_dir / f"{conversation_id}.pt").exists()

    original_num = len(dataset)
    dataset = dataset.filter(keep_conversation)
    print(f"Removed {original_num - len(dataset)} already-processed conversations")

    if args.debug_max_num_conversations is not None:
        # Clamp to the dataset size so a limit larger than the number of
        # remaining conversations does not index out of range.
        limit = min(args.debug_max_num_conversations, len(dataset))
        dataset = dataset.select(range(limit))

    # Determine aux layer indices per --aux-layers flag.
    # Convention bridge: resolve_aux_layers returns 0-based transformer layer
    # IDs (HF: outputs.hidden_states[lid + 1] = output of layer lid). vLLM's
    # `aux_hidden_state_layers` is checked against `idx + 1` after layer idx,
    # so the index there is also "lid + 1" — i.e. shift HF's lid by +1.
    # Last-layer capture: HF puts the post-final-norm result at hidden_states[N];
    # vLLM exposes the same position (idx+1 == N after layer N-1) but stores the
    # *pre-norm* residual stream there, which is fine for our consumer below.
    hf_config = AutoConfig.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
    num_hidden_layers = hf_config.num_hidden_layers
    aux_layer_ids = resolve_aux_layers(args, num_hidden_layers)
    aux_capture_ids_vllm = [lid + 1 for lid in aux_layer_ids]
    # All layers to capture: shifted aux layers + final-layer position N
    all_capture_ids = sorted({*aux_capture_ids_vllm, num_hidden_layers})
    print(
        f"Model has {num_hidden_layers} hidden layers; "
        f"aux layer ids (HF 0-based): {aux_layer_ids}, "
        f"vLLM capture ids: {all_capture_ids}"
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    override_template = load_chat_template(args.chat_template)
    if override_template is not None:
        tokenizer.chat_template = override_template
    if tokenizer.chat_template:
        # Drop the fragment that strips text before '</think>' from the template.
        tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
    if args.answer_only_loss:
        verify_generation_tags(tokenizer.chat_template)

    # Tokenize and filter conversations. The three lists are kept in lockstep
    # (one slot per accepted conversation) so that entries sharing a
    # conversation id still keep their own loss mask.
    token_id_list = []
    conversation_ids = []
    loss_masks = []
    num_skipped_by_length = 0
    num_invalid = 0

    for entry in dataset:
        conversation_id = _get_conversation_id(entry)
        conversations = entry.get("messages") or entry.get("conversations")
        if not conversations or not isinstance(conversations, list):
            num_invalid += 1
            continue

        # Single apply_chat_template call produces both input_ids and loss_mask,
        # guaranteeing they come from the same tokenization.
        input_ids, loss_mask = tokenize_with_loss_mask(
            tokenizer, conversations, args.answer_only_loss
        )
        num_tokens = input_ids.shape[-1]
        if num_tokens < args.min_seq_len or num_tokens > args.max_seq_len:
            num_skipped_by_length += 1
            continue

        token_id_list.append(input_ids.squeeze(0).tolist())
        conversation_ids.append(conversation_id)
        loss_masks.append(loss_mask)

    print(
        f"Tokenized {len(token_id_list)} conversations "
        f"(skipped {num_skipped_by_length} by length, {num_invalid} invalid)"
    )

    if not token_id_list:
        print("No conversations to process.")
        return

    # The connector writes its per-request dumps under tmpdir; everything is
    # read back and re-saved to output_dir before the directory is removed.
    with tempfile.TemporaryDirectory() as tmpdir:
        llm = LLM(
            model=args.model,
            speculative_config={
                "method": "extract_hidden_states",
                "num_speculative_tokens": 1,
                "draft_model_config": {
                    "hf_config": {
                        "eagle_aux_hidden_state_layer_ids": all_capture_ids,
                    }
                },
            },
            kv_transfer_config={
                "kv_connector": "ExampleHiddenStatesConnector",
                "kv_role": "kv_producer",
                "kv_connector_extra_config": {
                    "shared_storage_path": tmpdir,
                },
            },
            tensor_parallel_size=args.tp,
            trust_remote_code=args.trust_remote_code,
        )

        sampling_params = SamplingParams(max_tokens=1)
        outputs = llm.generate(
            [{"prompt_token_ids": ids} for ids in token_id_list],
            sampling_params,
        )

        num_success = 0
        for output, conversation_id, loss_mask in tqdm(
            zip(outputs, conversation_ids, loss_masks),
            total=len(outputs),
            desc="Saving hidden states",
        ):
            # kv_transfer_params may be None when the connector returned
            # nothing for this request; treat that like a missing path.
            kv_params = output.kv_transfer_params or {}
            hidden_states_path = kv_params.get("hidden_states_path")
            if hidden_states_path is None:
                print(
                    f"Warning: no hidden_states_path for conversation {conversation_id}, skipping"
                )
                continue

            with safe_open(hidden_states_path, framework="pt") as f:
                token_ids_tensor = f.get_tensor("token_ids")
                # Shape from vLLM: [seq_len, num_captured_layers, hidden_dim]
                hidden_states_tensor = f.get_tensor("hidden_states")

            # Last captured layer (= last model layer N-1) -> output hidden states
            # Earlier captured layers -> aux hidden states, concatenated along hidden dim
            output_hidden_states = hidden_states_tensor[:, -1, :]  # [seq_len, hidden_dim]
            # flatten() rather than reshape(seq_len, -1): with no aux layers the
            # slice has zero elements and a -1 dimension cannot be inferred.
            aux_hidden_states = hidden_states_tensor[:, :-1, :].flatten(
                start_dim=1
            )  # [seq_len, hidden_dim * num_aux_layers]

            # Align loss_mask with the token length returned by vLLM: if vLLM
            # truncated, truncate; if it somehow grew (shouldn't happen), pad with 1s
            # so that tail positions remain trainable under non-answer-only runs.
            # NOTE(review): assumes loss_mask is 1-D [seq_len] — confirm against
            # tokenize_with_loss_mask.
            loss_mask = _align_loss_mask(loss_mask, token_ids_tensor.shape[0])

            output_file = args.output_dir / f"{conversation_id}.pt"
            torch.save(
                {
                    "input_ids": token_ids_tensor.to(torch.int64),
                    "hidden_states": output_hidden_states,
                    "aux_hidden_states": aux_hidden_states,
                    "loss_mask": loss_mask,
                    "conversation_id": conversation_id,
                },
                output_file,
            )
            num_success += 1

    print(f"Successfully saved {num_success} / {len(token_id_list)} conversations.")
| 269 | + |
| 270 | + |
# Script entry point: parse the CLI arguments and run the extraction.
if __name__ == "__main__":
    main(parse_args())
0 commit comments