Skip to content
11 changes: 8 additions & 3 deletions examples/llm_eval/run_lm_eval_vllm.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@
# Script to run lm-evaluation-harness against a running vLLM OpenAI-compatible server.
#
# Usage:
# bash run_lm_eval_vllm.sh <model_name> [port] [task]
# bash run_lm_eval_vllm.sh <model_name> [port] [task] [host]
#
# Arguments:
# <model_name>: The name of the model being served (e.g., Qwen/Qwen3-30B-A3B). Used for the 'model' argument in lm_eval.
# [port]: The port the vLLM server is listening on (default: 8000).
# [task]: The lm_eval task(s) to run (default: mmlu).
# [host]: The IP address or hostname of the vLLM server (default: localhost).
#
# Example:
# # Start vLLM server first (in another terminal):
Expand All @@ -35,23 +36,27 @@
#
# # Run for a different task, e.g., hellaswag:
# bash run_lm_eval_vllm.sh Qwen/Qwen3-30B-A3B 8000 hellaswag
#
# # Run against a remote server:
# bash run_lm_eval_vllm.sh Qwen/Qwen3-30B-A3B 8000 mmlu 10.78.17.40
# ---

set -e
set -x

# --- Argument Parsing ---
if [ -z "$1" ]; then
echo "Usage: $0 <model_name> [port] [task]"
echo "Usage: $0 <model_name> [port] [task] [host]"
exit 1
fi
MODEL_NAME=$1
PORT=${2:-8000} # Default port is 8000 if not provided
TASK=${3:-mmlu} # Default task is mmlu if not provided
HOST=${4:-localhost} # Default host is localhost if not provided

# --- Environment Setup ---
export OPENAI_API_KEY="local" # Not strictly required for local, but good practice
BASE_URL="http://localhost:${PORT}/v1"
BASE_URL="http://${HOST}:${PORT}/v1"
COMPLETIONS_URL="${BASE_URL}/completions"

# --- Evaluation ---
Expand Down
262 changes: 247 additions & 15 deletions examples/llm_ptq/example_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,134 @@
except ImportError:
snapshot_download = None

from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor
from modelopt.torch.export.model_utils import match_model_type_by_name
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader
from modelopt.torch.utils.image_processor import (
BaseImageProcessor,
MllamaImageProcessor,
Qwen3OmniImageProcessor,
)
from modelopt.torch.utils.video_dataset_utils import (
Qwen3OmniVideoProcessor,
get_supported_video_datasets,
get_video_dataset_dataloader,
)
from modelopt.torch.utils.vlm_dataset_utils import (
get_supported_vlm_datasets,
get_vlm_dataset_dataloader,
)

logger = logging.getLogger(__name__)

SPECULATIVE_MODEL_LIST = ["Eagle", "Medusa"]

# Files needed for tokenizer/processor that vLLM loads from model path
TOKENIZER_FILES = [
"vocab.json",
"merges.txt",
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"preprocessor_config.json",
"chat_template.json",
]


def get_model_type_from_config(model_path: str) -> str | None:
"""Get model type from the config.json file.

Args:
model_path: Path to the model directory or HuggingFace model ID.

Returns:
Model type string (e.g., 'qwen3omni', 'llama', 'gpt') or None if not found.
"""
config_path = os.path.join(model_path, "config.json")
if not os.path.exists(config_path):
return None

with open(config_path) as f:
config = json.load(f)

# Check architectures field first
for arch in config.get("architectures", []):
result = match_model_type_by_name(arch)
if result is not None:
return result

# Fallback to model_type field
return match_model_type_by_name(config.get("model_type", ""))


def get_sampling_params_from_config(model_path: str) -> dict:
    """Extract sampling params from generation_config.json if present."""
    cfg_file = Path(model_path) / "generation_config.json"
    if not cfg_file.exists():
        return {}

    cfg = json.loads(cfg_file.read_text())

    params: dict = {}
    for name in ("temperature", "top_p", "top_k"):
        if name in cfg:
            params[name] = cfg[name]

    # First of max_new_tokens / max_length found wins as the token budget.
    for alias in ("max_new_tokens", "max_length"):
        if alias in cfg:
            params["max_tokens"] = cfg[alias]
            break

    return params


def get_quantization_format(model_path: str) -> str | None:
"""Get quantization format from the model config.

Args:
model_path: Path to the model directory.

Returns:
vLLM quantization string ('modelopt', 'modelopt_fp4') or None if not quantized.
"""
hf_quant_config_path = os.path.join(model_path, "hf_quant_config.json")
if os.path.exists(hf_quant_config_path):
with open(hf_quant_config_path) as f:
quant_config = json.load(f)
quant_algo = quant_config.get("quantization", {}).get("quant_algo", "")
if "NVFP4" in quant_algo:
return "modelopt_fp4"

return None


def ensure_tokenizer_files(model_path: str, source_model_id: str) -> None:
    """Copy tokenizer files from HF model to local quantized model dir if missing."""
    if not os.path.isdir(model_path):
        return  # Not a local path, nothing to do

    # Nothing to do when every tokenizer artifact is already present.
    if all(os.path.exists(os.path.join(model_path, name)) for name in TOKENIZER_FILES):
        return

    if snapshot_download is None:
        print("Warning: huggingface_hub not installed, cannot download tokenizer files")
        return

    print(f"Copying missing tokenizer files from {source_model_id}...")
    # A local source directory is used as-is; otherwise fetch only the tokenizer
    # files from the HF Hub.
    if os.path.isdir(source_model_id):
        src_dir = source_model_id
    else:
        src_dir = snapshot_download(
            source_model_id,
            allow_patterns=TOKENIZER_FILES,
        )

    for name in TOKENIZER_FILES:
        src = os.path.join(src_dir, name)
        dst = os.path.join(model_path, name)
        # Never overwrite files already present in the quantized checkpoint.
        if os.path.exists(src) and not os.path.exists(dst):
            shutil.copy2(src, dst)
            print(f" Copied {name}")


def run_nemotron_vl_preview(
full_model, tokenizer, input_ids, pyt_ckpt_path, stage_name, allow_fallback=False
Expand Down Expand Up @@ -241,9 +363,45 @@ def build_quant_cfg(
quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}

if model_type in ["qwen3moe", "qwen3next"] and qformat == "nvfp4":
# Disable the attention projection layers to retain accuracy
quant_cfg["quant_cfg"]["model*.*attn*in_proj*"] = {"enable": False}
quant_cfg["quant_cfg"]["model*.*attn*q_proj*"] = {"enable": False}
quant_cfg["quant_cfg"]["model*.*attn*k_proj*"] = {"enable": False}
quant_cfg["quant_cfg"]["model*.*attn*v_proj*"] = {"enable": False}

if model_type == "deepseek":
# Disable MLA quantization for accuracy.
quant_cfg["quant_cfg"]["*self_attn.q*"] = {"enable": False}
quant_cfg["quant_cfg"]["*self_attn.kv*"] = {"enable": False}

if model_type == "qwen3omni":
print(
"Disabling quantization for conv layers, audio tower and visual encoder in Qwen3Omni model"
)
quant_cfg["quant_cfg"]["*conv*"] = {"enable": False}
quant_cfg["quant_cfg"]["*audio_tower*"] = {"enable": False}
quant_cfg["quant_cfg"]["*visual*"] = {"enable": False}

return quant_cfg


def get_generation_kwargs(model_type: str) -> dict[str, Any]:
    """Get model-specific generation kwargs for calibration.

    Args:
        model_type: The model type string.

    Returns:
        Dictionary of generation kwargs for the model.
    """
    # Qwen3Omni: skip audio synthesis and cap the thinker's output during calibration.
    if model_type == "qwen3omni":
        return {"return_audio": False, "thinker_max_new_tokens": 1}
    return {}


def is_speculative(hf_config):
"""Check if the model architecture is a speculative model."""
return hf_config.architectures and any(
Expand Down Expand Up @@ -284,7 +442,7 @@ def get_processor(
if attn_implementation is not None:
model_kwargs["attn_implementation"] = attn_implementation

if model_type == "whisper":
if model_type in ("whisper", "mllama", "qwen3omni"):
processor = AutoProcessor.from_pretrained(
ckpt_path,
padding_side="left",
Expand All @@ -296,20 +454,11 @@ def get_processor(
f"Pad token for {ckpt_path} cannot be set!"
)

if model_type == "mllama":
return MllamaImageProcessor(processor, device)
elif model_type == "qwen3omni":
return Qwen3OmniImageProcessor(processor, device)
return processor
elif model_type == "mllama":
processor = AutoProcessor.from_pretrained(
ckpt_path,
padding_side="left",
**model_kwargs,
)
if processor.tokenizer.pad_token is None:
processor.tokenizer.pad_token = processor.tokenizer.eos_token
assert processor.tokenizer.pad_token is not None, (
f"Pad token for {ckpt_path} cannot be set!"
)

return MllamaImageProcessor(processor, device)
else:
# Try to load AutoProcessor for other VL models (e.g., Nemotron-Parse)
try:
Expand Down Expand Up @@ -838,3 +987,86 @@ def copy_custom_model_files(source_path: str, export_path: str, trust_remote_cod
print(f"Successfully copied {len(copied_files)} custom model files to {export_path}")
else:
print("No custom model files found to copy")


def get_qwen3omni_dataloader(
    dataset_name: str | list[str] | None,
    processor: Qwen3OmniImageProcessor | None,
    tokenizer,
    batch_size: int,
    num_samples: int | list[int],
    device: torch.device,
    model_dtype: torch.dtype,
    include_labels: bool = False,
):
    """Create a calibration dataloader for Qwen3Omni models.

    Handles video, VLM, and text-only dataset configurations.

    Args:
        dataset_name: Name of the dataset(s) to use for calibration.
        processor: The Qwen3OmniImageProcessor for multimodal inputs.
        tokenizer: The tokenizer for text-only fallback.
        batch_size: Batch size for the dataloader.
        num_samples: Number of samples to use (int or list for multi-dataset).
        device: Target device for tensors.
        model_dtype: Model dtype for proper tensor conversion.
        include_labels: Whether to include labels (for gradient-based auto_quantize).

    Returns:
        DataLoader for calibration.
    """
    if dataset_name is None:
        dataset_name = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
        num_samples = [512, 512]

    if processor is None:
        # Text-only fallback
        return get_dataset_dataloader(
            dataset_name=dataset_name if isinstance(dataset_name, list) else [dataset_name],
            tokenizer=tokenizer,
            batch_size=batch_size,
            num_samples=num_samples if isinstance(num_samples, list) else [num_samples],
            device=device,
            include_labels=include_labels,
        )

    # Normalize single-element list to str for supported-dataset lookups
    if isinstance(dataset_name, list) and len(dataset_name) == 1:
        dataset_name = dataset_name[0]

    # Multimodal loaders take a single sample count, so collapse any list.
    sample_count = num_samples if isinstance(num_samples, int) else num_samples[0]

    if dataset_name in get_supported_video_datasets():
        assert isinstance(dataset_name, str)
        video_processor = Qwen3OmniVideoProcessor(
            getattr(processor, "tokenizer", processor),
            device=device,
            dtype=model_dtype,
            use_audio_in_video=True,
        )
        return get_video_dataset_dataloader(
            dataset_name=dataset_name,
            processor=video_processor,
            batch_size=batch_size,
            num_samples=sample_count,
        )

    if dataset_name in get_supported_vlm_datasets():
        assert isinstance(dataset_name, str)
        assert isinstance(processor, Qwen3OmniImageProcessor), (
            "The Qwen3OmniImageProcessor must be set."
        )
        # Set dtype for proper tensor conversion in collate_function.
        # Processor is created before model_dtype is known, so we set it here.
        processor.dtype = model_dtype
        return get_vlm_dataset_dataloader(
            dataset_name=dataset_name,
            processor=processor,
            batch_size=batch_size,
            num_samples=sample_count,
        )

    raise ValueError(
        f"Dataset '{dataset_name}' not supported for Qwen3Omni with processor. "
        f"Supported video datasets: {get_supported_video_datasets()}, "
        f"Supported VLM datasets: {get_supported_vlm_datasets()}"
    )
Loading
Loading