Commit 41ee25b

Create a script to run vllm inference

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 3f5859b commit 41ee25b

2 files changed

Lines changed: 247 additions & 0 deletions

File tree

examples/llm_ptq/example_utils.py
examples/llm_ptq/run_vllm.py

examples/llm_ptq/example_utils.py

Lines changed: 111 additions & 0 deletions
@@ -46,6 +46,7 @@
     snapshot_download = None
 
 import modelopt.torch.quantization as mtq
+from modelopt.torch.export.model_utils import MODEL_NAME_TO_TYPE
 from modelopt.torch.utils.dataset_utils import get_dataset_dataloader
 from modelopt.torch.utils.image_processor import (
     BaseImageProcessor,
@@ -66,6 +67,116 @@
 
 SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]
 
+# Files needed for the tokenizer/processor that vLLM loads from the model path
+TOKENIZER_FILES = [
+    "vocab.json",
+    "merges.txt",
+    "tokenizer.json",
+    "tokenizer_config.json",
+    "special_tokens_map.json",
+    "preprocessor_config.json",
+    "chat_template.json",
+]
+
+
+def get_model_type_from_config(model_path: str) -> str | None:
+    """Get the model type from the config.json file.
+
+    Args:
+        model_path: Path to the model directory or HuggingFace model ID.
+
+    Returns:
+        Model type string (e.g., 'qwen3omni', 'llama', 'gpt') or None if not found.
+    """
+    config_path = os.path.join(model_path, "config.json")
+    if not os.path.exists(config_path):
+        return None
+
+    with open(config_path) as f:
+        config = json.load(f)
+
+    # Check the architectures field first
+    architectures = config.get("architectures", [])
+    for arch in architectures:
+        for key, model_type in MODEL_NAME_TO_TYPE.items():
+            if key.lower() in arch.lower():
+                return model_type
+
+    # Fall back to the model_type field
+    model_type_field = config.get("model_type", "")
+    for key, model_type in MODEL_NAME_TO_TYPE.items():
+        if key.lower() in model_type_field.lower():
+            return model_type
+
+    return None
+
+
+def get_sampling_params_from_config(model_path: str) -> dict:
+    """Extract sampling params from generation_config.json if present."""
+    gen_config_path = Path(model_path) / "generation_config.json"
+    if not gen_config_path.exists():
+        return {}
+
+    gen_config = json.loads(gen_config_path.read_text())
+
+    params = {k: gen_config[k] for k in ("temperature", "top_p", "top_k") if k in gen_config}
+
+    for key in ("max_new_tokens", "max_length"):
+        if key in gen_config:
+            params["max_tokens"] = gen_config[key]
+            break
+
+    return params
+
+
+def get_quantization_format(model_path: str) -> str | None:
+    """Get the quantization format from the model config.
+
+    Args:
+        model_path: Path to the model directory.
+
+    Returns:
+        vLLM quantization string ('modelopt', 'modelopt_fp4') or None if not quantized.
+    """
+    hf_quant_config_path = os.path.join(model_path, "hf_quant_config.json")
+    if os.path.exists(hf_quant_config_path):
+        with open(hf_quant_config_path) as f:
+            quant_config = json.load(f)
+        quant_algo = quant_config.get("quantization", {}).get("quant_algo", "")
+        if "NVFP4" in quant_algo:
+            return "modelopt_fp4"
+        # Other ModelOpt quant algos (e.g., FP8) map to the generic 'modelopt' backend
+        return "modelopt"
+
+    return None
+
+
+def ensure_tokenizer_files(model_path: str, source_model_id: str) -> None:
+    """Copy tokenizer files from the HF model to the local quantized model dir if missing."""
+    if not os.path.isdir(model_path):
+        return  # Not a local path, nothing to do
+
+    # Check whether any tokenizer files are missing
+    missing_files = [f for f in TOKENIZER_FILES if not os.path.exists(os.path.join(model_path, f))]
+    if not missing_files:
+        return
+
+    if snapshot_download is None:
+        print("Warning: huggingface_hub not installed, cannot download tokenizer files")
+        return
+
+    print(f"Copying missing tokenizer files from {source_model_id}...")
+    # Download only the tokenizer files from HF
+    cache_dir = snapshot_download(
+        source_model_id,
+        allow_patterns=TOKENIZER_FILES,
+    )
+
+    for fname in TOKENIZER_FILES:
+        src = os.path.join(cache_dir, fname)
+        dst = os.path.join(model_path, fname)
+        if os.path.exists(src) and not os.path.exists(dst):
+            shutil.copy2(src, dst)
+            print(f"  Copied {fname}")
 
 
 def run_nemotron_vl_preview(
     full_model, tokenizer, input_ids, pyt_ckpt_path, stage_name, allow_fallback=False
examples/llm_ptq/run_vllm.py

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unified HF checkpoint inference with vLLM.
+
+Usage:
+    python run_vllm.py --model /path/to/quantized/model
+    python run_vllm.py --model /path/to/model --tp 4
+"""
+
+from __future__ import annotations
+
+import argparse
+
+from example_utils import (
+    ensure_tokenizer_files,
+    get_model_type_from_config,
+    get_quantization_format,
+    get_sampling_params_from_config,
+)
+from transformers import AutoConfig, AutoProcessor
+from vllm import LLM, SamplingParams
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Run unified HF checkpoint inference with vLLM")
+    parser.add_argument("--model", type=str, required=True, help="Model ID or path")
+    parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
+    parser.add_argument(
+        "--max-model-len",
+        type=int,
+        default=None,
+        help="Max model length (auto-detected from config if not specified)",
+    )
+    parser.add_argument("--prompt", type=str, default="What is Nvidia?", help="Text prompt")
+    parser.add_argument(
+        "--tokenizer", type=str, default=None, help="Tokenizer ID or path (defaults to model path)"
+    )
+    parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature")
+    parser.add_argument("--top-p", type=float, default=0.9, help="Top-p sampling")
+    parser.add_argument("--top-k", type=int, default=-1, help="Top-k sampling (-1 to disable)")
+    parser.add_argument("--max-tokens", type=int, default=512, help="Max tokens to generate")
+
+    args = parser.parse_args()
+
+    # Detect the model type from config
+    model_type = get_model_type_from_config(args.model)
+    print(f"Detected model type: {model_type}")
+
+    # Detect the quantization format
+    quantization = get_quantization_format(args.model)
+    print(f"Detected quantization: {quantization}")
+
+    # Get max_model_len from the config if not specified
+    if args.max_model_len is None:
+        config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)
+        args.max_model_len = getattr(config, "max_position_embeddings", 4096)
+        print(f"Using max_model_len from config: {args.max_model_len}")
+
+    # Determine the tokenizer source
+    tokenizer_id = args.tokenizer or args.model
+
+    # Load the processor for the chat template
+    processor = AutoProcessor.from_pretrained(tokenizer_id, trust_remote_code=True)
+
+    # Text-only conversations
+    conversations = [
+        [
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": args.prompt}],
+            }
+        ],
+    ]
+
+    # Apply the chat template
+    apply_chat_kwargs = {
+        "add_generation_prompt": True,
+        "tokenize": False,
+    }
+    # Qwen3Omni-specific: disable thinking mode
+    if model_type == "qwen3omni":
+        apply_chat_kwargs["enable_thinking"] = False
+
+    texts = processor.apply_chat_template(conversations, **apply_chat_kwargs)
+
+    # Ensure tokenizer files exist in the local model dir (vLLM loads the processor from the model path)
+    if args.tokenizer:
+        ensure_tokenizer_files(args.model, args.tokenizer)
+
+    print(f"Loading model: {args.model}")
+    llm = LLM(
+        model=args.model,
+        tokenizer=tokenizer_id,
+        tensor_parallel_size=args.tp,
+        max_model_len=args.max_model_len,
+        trust_remote_code=True,
+        quantization=quantization,
+    )
+
+    # Get sampling params from the config, with CLI values/defaults as fallback
+    config_params = get_sampling_params_from_config(args.model)
+    sampling_kwargs = {
+        "temperature": config_params.get("temperature", args.temperature),
+        "top_p": config_params.get("top_p", args.top_p),
+        "max_tokens": config_params.get("max_tokens", args.max_tokens),
+    }
+    top_k = config_params.get("top_k", args.top_k)
+    if top_k > 0:
+        sampling_kwargs["top_k"] = top_k
+    print(f"Sampling params: {sampling_kwargs}")
+    sampling_params = SamplingParams(**sampling_kwargs)
+
+    print("Running inference...")
+    outputs = llm.generate(texts, sampling_params)
+
+    for output in outputs:
+        generated_text = output.outputs[0].text
+        print("-" * 80)
+        print(f"Generated: {generated_text}")
+
+
+if __name__ == "__main__":
+    main()
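
For reference, two hedged invocation examples in the style of the script's Usage docstring; the checkpoint path is hypothetical, and the --tokenizer flag is only needed when the quantized directory is missing tokenizer files:

    # NVFP4 checkpoint: quantization format and max_model_len are auto-detected from the checkpoint's configs
    python run_vllm.py --model /ckpts/llama3-nvfp4 --tp 2 --max-tokens 256

    # Copy missing tokenizer files from the source HF repo before loading
    python run_vllm.py --model /ckpts/llama3-nvfp4 --tokenizer meta-llama/Llama-3.1-8B-Instruct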
