Skip to content

Commit 2ca3e68

Browse files
feat: Add HuggingFace Hub model support for ORT and TensorRT inference benchmarks
- Add HuggingFaceModelLoader for downloading and caching models from HF Hub
- Support both NLP (AutoModelForCausalLM) and vision (AutoModelForImageClassification) models
- Add model_source and model_identifier parameters to TensorRT/ORT benchmarks
- Add ONNX export pipeline for HuggingFace models with dynamic axes
- Derive vision input shapes from ONNX graph dims with HF config fallback
- Filter ONNX initializers from graph.input for correct NLP input handling
- Add PyTorch 2.8+ compatibility (external_data vs use_external_data_format)
- Add example script, unit tests, and config schema updates
- Support HF_TOKEN env var for gated model access
1 parent 700d650 commit 2ca3e68

11 files changed

Lines changed: 1467 additions & 15 deletions

examples/benchmarks/ort_inference_performance.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,30 @@
44
"""Micro benchmark example for ONNXRuntime inference performance.
55
66
Commands to run:
7+
In-house models:
78
python3 examples/benchmarks/ort_inference_performance.py
9+
python3 examples/benchmarks/ort_inference_performance.py --model_source in-house
10+
11+
HuggingFace models:
12+
python3 examples/benchmarks/ort_inference_performance.py \
13+
--model_source huggingface --model_identifier bert-base-uncased
14+
python3 examples/benchmarks/ort_inference_performance.py \
15+
--model_source huggingface --model_identifier microsoft/resnet-50
16+
python3 examples/benchmarks/ort_inference_performance.py \
17+
--model_source huggingface --model_identifier deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
18+
19+
Environment variables:
20+
HF_TOKEN: HuggingFace token for gated models (optional)
821
"""
922

23+
import argparse
24+
1025
from superbench.benchmarks import BenchmarkRegistry, Platform
1126
from superbench.common.utils import logger
1227

13-
if __name__ == '__main__':
28+
29+
def run_inhouse_benchmark():
30+
"""Run ORT inference with in-house torchvision models."""
1431
context = BenchmarkRegistry.create_benchmark_context(
1532
'ort-inference', platform=Platform.CUDA, parameters='--pytorch_models resnet50 resnet101 --precision float16'
1633
)
@@ -21,3 +38,57 @@
2138
benchmark.name, benchmark.return_code, benchmark.result
2239
)
2340
)
41+
return benchmark
42+
43+
44+
def run_huggingface_benchmark(model_identifier, precision='float16', batch_size=32, seq_length=512):
    """Launch the 'ort-inference' benchmark against a HuggingFace Hub model.

    Args:
        model_identifier: HuggingFace model ID (e.g., 'bert-base-uncased').
        precision: Inference precision ('float32', 'float16', 'int8').
        batch_size: Batch size for inference.
        seq_length: Sequence length for transformer models.

    Returns:
        Whatever ``BenchmarkRegistry.launch_benchmark`` returns; a summary is
        logged only when that value is truthy.
    """
    # Assemble the benchmark command line as a single space-separated string.
    cli_options = [
        '--model_source huggingface',
        f'--model_identifier {model_identifier}',
        f'--precision {precision}',
        f'--batch_size {batch_size}',
        f'--seq_length {seq_length}',
    ]
    parameters = ' '.join(cli_options)

    logger.info(f'Running ORT inference benchmark with HuggingFace model: {model_identifier}')

    benchmark = BenchmarkRegistry.launch_benchmark(
        BenchmarkRegistry.create_benchmark_context('ort-inference', platform=Platform.CUDA, parameters=parameters)
    )
    if not benchmark:
        return benchmark

    logger.info(
        f'benchmark: {benchmark.name}, return code: {benchmark.return_code}, result: {benchmark.result}'
    )
    return benchmark
72+
73+
74+
if __name__ == '__main__':
    # Command-line entry point: pick the in-house or the HuggingFace code path.
    cli = argparse.ArgumentParser(description='ORT inference benchmark')
    cli.add_argument(
        '--model_source',
        default='in-house',
        type=str,
        choices=['in-house', 'huggingface'],
        help='Source of the model: in-house (default) or huggingface',
    )
    cli.add_argument('--model_identifier', default='bert-base-uncased', type=str, help='HuggingFace model identifier')
    cli.add_argument('--precision', default='float16', type=str, choices=['float32', 'float16', 'int8'])
    cli.add_argument('--batch_size', default=32, type=int)
    cli.add_argument('--seq_length', default=512, type=int)
    options = cli.parse_args()

    # The remaining options are only consumed by the HuggingFace path.
    if options.model_source != 'huggingface':
        run_inhouse_benchmark()
    else:
        run_huggingface_benchmark(
            options.model_identifier, options.precision, options.batch_size, options.seq_length
        )

examples/benchmarks/tensorrt_inference_performance.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,30 @@
44
"""Micro benchmark example for TensorRT inference performance.
55
66
Commands to run:
7+
In-house models:
78
python3 examples/benchmarks/tensorrt_inference_performance.py
9+
python3 examples/benchmarks/tensorrt_inference_performance.py --model_source in-house
10+
11+
HuggingFace models:
12+
python3 examples/benchmarks/tensorrt_inference_performance.py \
13+
--model_source huggingface --model_identifier bert-base-uncased
14+
python3 examples/benchmarks/tensorrt_inference_performance.py \
15+
--model_source huggingface --model_identifier microsoft/resnet-50
16+
python3 examples/benchmarks/tensorrt_inference_performance.py \
17+
--model_source huggingface --model_identifier deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
18+
19+
Environment variables:
20+
HF_TOKEN: HuggingFace token for gated models (optional)
821
"""
922

23+
import argparse
24+
1025
from superbench.benchmarks import BenchmarkRegistry, Platform
1126
from superbench.common.utils import logger
1227

13-
if __name__ == '__main__':
28+
29+
def run_inhouse_benchmark():
30+
"""Run TensorRT inference with in-house torchvision models."""
1431
context = BenchmarkRegistry.create_benchmark_context('tensorrt-inference', platform=Platform.CUDA)
1532
benchmark = BenchmarkRegistry.launch_benchmark(context)
1633
if benchmark:
@@ -19,3 +36,64 @@
1936
benchmark.name, benchmark.return_code, benchmark.result
2037
)
2138
)
39+
return benchmark
40+
41+
42+
def run_huggingface_benchmark(model_identifier, precision='fp16', batch_size=32, seq_length=512, iterations=2048):
    """Launch the 'tensorrt-inference' benchmark against a HuggingFace Hub model.

    Args:
        model_identifier: HuggingFace model ID (e.g., 'bert-base-uncased').
        precision: Inference precision ('fp32', 'fp16', 'int8').
        batch_size: Batch size for inference.
        seq_length: Sequence length for transformer models.
        iterations: Number of inference iterations.

    Returns:
        Whatever ``BenchmarkRegistry.launch_benchmark`` returns; a summary is
        logged only when that value is truthy.
    """
    # Assemble the benchmark command line as a single space-separated string.
    cli_options = [
        '--model_source huggingface',
        f'--model_identifier {model_identifier}',
        f'--precision {precision}',
        f'--batch_size {batch_size}',
        f'--seq_length {seq_length}',
        f'--iterations {iterations}',
    ]
    parameters = ' '.join(cli_options)

    logger.info(f'Running TensorRT inference benchmark with HuggingFace model: {model_identifier}')

    benchmark = BenchmarkRegistry.launch_benchmark(
        BenchmarkRegistry.create_benchmark_context(
            'tensorrt-inference', platform=Platform.CUDA, parameters=parameters
        )
    )
    if not benchmark:
        return benchmark

    logger.info(
        f'benchmark: {benchmark.name}, return code: {benchmark.return_code}, result: {benchmark.result}'
    )
    return benchmark
74+
75+
76+
if __name__ == '__main__':
    # Command-line entry point: pick the in-house or the HuggingFace code path.
    cli = argparse.ArgumentParser(description='TensorRT inference benchmark')
    cli.add_argument(
        '--model_source',
        default='in-house',
        type=str,
        choices=['in-house', 'huggingface'],
        help='Source of the model: in-house (default) or huggingface',
    )
    cli.add_argument('--model_identifier', default='bert-base-uncased', type=str, help='HuggingFace model identifier')
    cli.add_argument('--precision', default='fp16', type=str, choices=['fp32', 'fp16', 'int8'])
    cli.add_argument('--batch_size', default=32, type=int)
    cli.add_argument('--seq_length', default=512, type=int)
    cli.add_argument('--iterations', default=2048, type=int)
    options = cli.parse_args()

    # The remaining options are only consumed by the HuggingFace path.
    if options.model_source != 'huggingface':
        run_inhouse_benchmark()
    else:
        run_huggingface_benchmark(
            options.model_identifier,
            options.precision,
            options.batch_size,
            options.seq_length,
            options.iterations,
        )

superbench/benchmarks/micro_benchmarks/_export_torch_to_onnx.py

Lines changed: 160 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,30 @@
33

44
"""Export PyTorch models to ONNX format."""
55

6+
import inspect
67
from pathlib import Path
78

89
from packaging import version
910
import torch.hub
1011
import torch.onnx
1112
import torchvision.models
12-
from transformers import BertConfig, GPT2Config, LlamaConfig
1313

14-
from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel
15-
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
16-
from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
17-
from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel
18-
from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel
14+
import traceback
1915

20-
if MixtralBenchmarkModel is not None:
21-
from transformers import MixtralConfig
16+
from superbench.common.utils import logger
2217

2318

2419
class torch2onnxExporter():
2520
"""PyTorch model to ONNX exporter."""
2621
def __init__(self):
2722
"""Constructor."""
23+
from transformers import BertConfig, GPT2Config, LlamaConfig
24+
from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel
25+
from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel
26+
from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel
27+
from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel
28+
from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel
29+
2830
self.num_classes = 100
2931
self.lstm_input_size = 256
3032
self.benchmark_models = {
@@ -129,6 +131,7 @@ def __init__(self):
129131

130132
# Only include Mixtral models if MixtralBenchmarkModel is available
131133
if MixtralBenchmarkModel is not None:
134+
from transformers import MixtralConfig
132135
self.benchmark_models.update(
133136
{
134137
'mixtral-8x7b':
@@ -270,3 +273,152 @@ def export_benchmark_model(self, model_name, batch_size=1, seq_length=512):
270273
del dummy_input
271274
torch.cuda.empty_cache()
272275
return file_name
276+
277+
def export_huggingface_model(self, model, model_name, batch_size=1, seq_length=512, output_dir=None):
    """Export a HuggingFace model to ONNX format.

    The model is switched to eval mode, moved to CUDA when available, wrapped so
    that it returns a single tensor, and traced with dummy inputs. Models whose
    ``main_input_name`` is ``'pixel_values'`` are exported with one image input;
    all other models are exported with ``input_ids`` and ``attention_mask``.

    Args:
        model: HuggingFace model instance to export.
        model_name (str): Name for the exported ONNX model file (without extension).
        batch_size (int): Batch size of the dummy input. Defaults to 1.
        seq_length (int): Sequence length of the dummy input (NLP models only). Defaults to 512.
        output_dir (str): Output directory path. If None, uses ``self._onnx_model_path``.

    Returns:
        str: Exported ONNX model file path, or empty string if export fails.
    """
    def select_output(outputs):
        """Pick the single tensor that becomes the ONNX graph output."""
        if hasattr(outputs, 'logits'):
            return outputs.logits
        if hasattr(outputs, 'last_hidden_state'):
            return outputs.last_hidden_state
        return outputs[0] if isinstance(outputs, (tuple, list)) else outputs

    class VisionModelWrapper(torch.nn.Module):
        """Expose a vision model as ``pixel_values -> tensor``."""
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, pixel_values):
            return select_output(self.model(pixel_values=pixel_values))

    class NLPModelWrapper(torch.nn.Module):
        """Expose an NLP model as ``(input_ids, attention_mask) -> tensor``."""
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids, attention_mask):
            return select_output(self.model(input_ids=input_ids, attention_mask=attention_mask))

    try:
        # Use custom output directory if provided, and make sure it exists so the
        # exporter does not fail on a fresh path.
        output_path = Path(output_dir) if output_dir else Path(self._onnx_model_path)
        output_path.mkdir(parents=True, exist_ok=True)
        file_name = str(output_path / (model_name + '.onnx'))

        model.eval()

        # Disable cache to avoid DynamicCache issues with ONNX export
        if hasattr(model.config, 'use_cache'):
            model.config.use_cache = False

        use_cuda = torch.cuda.is_available()
        if use_cuda:
            model = model.cuda()
        device = 'cuda' if use_cuda else 'cpu'

        # Dummy floating-point inputs must match the dtype of the model weights.
        model_dtype = next(model.parameters()).dtype

        # Vision models use pixel_values, NLP models use input_ids; rely on
        # HuggingFace's main_input_name attribute to tell them apart.
        main_input = getattr(model, 'main_input_name', 'input_ids')
        is_vision_model = main_input == 'pixel_values'

        if is_vision_model:
            # Derive C/H/W from model config rather than hard-coding 3x224x224.
            num_channels = getattr(model.config, 'num_channels', 3)
            image_size = getattr(model.config, 'image_size', 224)
            if isinstance(image_size, (list, tuple)):
                img_h, img_w = image_size[0], image_size[1]
            else:
                img_h, img_w = image_size, image_size

            export_args = (
                torch.randn(batch_size, num_channels, img_h, img_w, dtype=model_dtype, device=device),
            )
            input_names = ['pixel_values']
            dynamic_axes = {'pixel_values': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
            wrapped_model = VisionModelWrapper(model)
        else:
            token_shape = (batch_size, seq_length)
            export_args = (
                torch.ones(token_shape, dtype=torch.int64, device=device),    # input_ids
                torch.ones(token_shape, dtype=torch.int64, device=device),    # attention_mask
            )
            input_names = ['input_ids', 'attention_mask']
            token_axes = {0: 'batch_size', 1: 'seq_length'}
            # NOTE(review): axis 1 of 'output' is labelled seq_length for every NLP model;
            # confirm this is intended for heads whose output has no sequence axis.
            dynamic_axes = {
                'input_ids': dict(token_axes),
                'attention_mask': dict(token_axes),
                'output': dict(token_axes),
            }
            wrapped_model = NLPModelWrapper(model)

        # Models whose weights exceed 2GB are exported with external data enabled.
        model_size_gb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**3)
        use_external_data = model_size_gb > 2.0

        export_kwargs = {
            'opset_version': 14,
            'do_constant_folding': True,
            'input_names': input_names,
            'output_names': ['output'],
            'dynamic_axes': dynamic_axes,
        }
        if use_external_data:
            logger.info(f'Model size is {model_size_gb:.2f}GB, using external data format for ONNX export')
            # The keyword differs between PyTorch releases ('external_data' vs the older
            # 'use_external_data_format'). Only pass a keyword the installed exporter
            # actually declares: an unknown keyword would raise TypeError.
            export_signature = inspect.signature(torch.onnx.export).parameters
            if 'external_data' in export_signature:
                export_kwargs['external_data'] = True
            elif 'use_external_data_format' in export_signature:
                export_kwargs['use_external_data_format'] = True

        torch.onnx.export(
            wrapped_model,
            export_args,
            file_name,
            **export_kwargs,
        )

        # Drop the references to the dummy tensors before asking CUDA to release cached blocks.
        del export_args
        if use_cuda:
            torch.cuda.empty_cache()

        return file_name

    except Exception as e:
        # Best-effort contract: report the failure and signal it with an empty path.
        logger.error(f'Failed to export HuggingFace model to ONNX: {str(e)}')
        logger.error(traceback.format_exc())
        return ''

0 commit comments

Comments
 (0)