|
3 | 3 |
|
4 | 4 | """Export PyTorch models to ONNX format.""" |
5 | 5 |
|
| 6 | +import inspect |
6 | 7 | from pathlib import Path |
7 | 8 |
|
8 | 9 | from packaging import version |
9 | 10 | import torch.hub |
10 | 11 | import torch.onnx |
11 | 12 | import torchvision.models |
12 | | -from transformers import BertConfig, GPT2Config, LlamaConfig |
13 | 13 |
|
14 | | -from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel |
15 | | -from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel |
16 | | -from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel |
17 | | -from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel |
18 | | -from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel |
| 14 | +import traceback |
19 | 15 |
|
20 | | -if MixtralBenchmarkModel is not None: |
21 | | - from transformers import MixtralConfig |
| 16 | +from superbench.common.utils import logger |
22 | 17 |
|
23 | 18 |
|
24 | 19 | class torch2onnxExporter(): |
25 | 20 | """PyTorch model to ONNX exporter.""" |
26 | 21 | def __init__(self): |
27 | 22 | """Constructor.""" |
| 23 | + from transformers import BertConfig, GPT2Config, LlamaConfig |
| 24 | + from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel |
| 25 | + from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel |
| 26 | + from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel |
| 27 | + from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel |
| 28 | + from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel |
| 29 | + |
28 | 30 | self.num_classes = 100 |
29 | 31 | self.lstm_input_size = 256 |
30 | 32 | self.benchmark_models = { |
@@ -129,6 +131,7 @@ def __init__(self): |
129 | 131 |
|
130 | 132 | # Only include Mixtral models if MixtralBenchmarkModel is available |
131 | 133 | if MixtralBenchmarkModel is not None: |
| 134 | + from transformers import MixtralConfig |
132 | 135 | self.benchmark_models.update( |
133 | 136 | { |
134 | 137 | 'mixtral-8x7b': |
@@ -270,3 +273,152 @@ def export_benchmark_model(self, model_name, batch_size=1, seq_length=512): |
270 | 273 | del dummy_input |
271 | 274 | torch.cuda.empty_cache() |
272 | 275 | return file_name |
| 276 | + |
def export_huggingface_model(self, model, model_name, batch_size=1, seq_length=512, output_dir=None):
    """Export a HuggingFace model to ONNX format.

    The model is modified in place before export: it is switched to eval mode,
    ``config.use_cache`` is set to False when that attribute exists, and the model
    is moved to CUDA when CUDA is available.

    Args:
        model: HuggingFace model instance to export.
        model_name (str): Name for the exported ONNX model file (without the '.onnx' suffix).
        batch_size (int): Batch size of the dummy input. Defaults to 1.
        seq_length (int): Sequence length of the dummy input, used for non-vision models only.
            Defaults to 512.
        output_dir (str): Output directory path, created if it does not exist.
            If None, uses the exporter's default ONNX model path.

    Returns:
        str: Exported ONNX model file path, or empty string if export fails.
    """
    def _primary_output(outputs):
        """Select the single tensor that is exposed as the ONNX graph output."""
        if hasattr(outputs, 'logits'):
            return outputs.logits
        if hasattr(outputs, 'last_hidden_state'):
            return outputs.last_hidden_state
        return outputs[0] if isinstance(outputs, (tuple, list)) else outputs

    class VisionModelWrapper(torch.nn.Module):
        """Adapter exposing a positional ``pixel_values`` input and a single tensor output."""
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, pixel_values):
            return _primary_output(self.model(pixel_values=pixel_values))

    class NLPModelWrapper(torch.nn.Module):
        """Adapter exposing positional ``input_ids``/``attention_mask`` inputs and a single tensor output."""
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids, attention_mask):
            return _primary_output(self.model(input_ids=input_ids, attention_mask=attention_mask))

    try:
        # Use custom output directory if provided, and make sure it exists so that
        # the exporter does not fail on a fresh directory.
        output_path = Path(output_dir) if output_dir else Path(self._onnx_model_path)
        output_path.mkdir(parents=True, exist_ok=True)
        file_name = str(output_path / (model_name + '.onnx'))

        # Put model in eval mode
        model.eval()

        # Disable cache to avoid DynamicCache issues with ONNX export
        if hasattr(model.config, 'use_cache'):
            model.config.use_cache = False

        # Move to CUDA if available; dummy inputs are created on the same device
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            model = model.cuda()
        device = 'cuda' if cuda_available else 'cpu'

        # Get model's dtype for inputs; fall back to float32 for a parameter-free module
        first_param = next(model.parameters(), None)
        model_dtype = first_param.dtype if first_param is not None else torch.float32

        # Detect model type and create appropriate inputs.
        # Vision models use pixel_values, NLP models use input_ids.
        # Use HuggingFace's main_input_name property for automatic detection.
        main_input = getattr(model, 'main_input_name', 'input_ids')

        if main_input == 'pixel_values':
            # Vision models: pixel_values is (batch_size, channels, height, width).
            # Derive C/H/W from model config rather than hard-coding 3x224x224.
            num_channels = getattr(model.config, 'num_channels', 3)
            image_size = getattr(model.config, 'image_size', 224)
            if isinstance(image_size, (list, tuple)):
                img_h, img_w = image_size[0], image_size[1]
            else:
                img_h, img_w = image_size, image_size

            dummy_input = torch.randn(batch_size, num_channels, img_h, img_w, dtype=model_dtype, device=device)
            input_names = ['pixel_values']
            dynamic_axes = {'pixel_values': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
            wrapped_model = VisionModelWrapper(model)
            export_args = (dummy_input, )
        else:
            # NLP models: use input_ids and attention_mask
            dummy_input = torch.ones((batch_size, seq_length), dtype=torch.int64, device=device)
            attention_mask = torch.ones((batch_size, seq_length), dtype=torch.int64, device=device)
            input_names = ['input_ids', 'attention_mask']
            # NOTE(review): axis 1 of 'output' is labelled 'seq_length'; for heads whose output
            # is not per-token this label may not describe a sequence axis - confirm.
            token_axes = {0: 'batch_size', 1: 'seq_length'}
            dynamic_axes = {
                'input_ids': dict(token_axes),
                'attention_mask': dict(token_axes),
                'output': dict(token_axes),
            }
            wrapped_model = NLPModelWrapper(model)
            export_args = (dummy_input, attention_mask)

        # For large models (>2GB of parameters), request the external data format
        model_size_gb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**3)
        use_external_data = model_size_gb > 2.0

        export_kwargs = {
            'opset_version': 14,
            'do_constant_folding': True,
            'input_names': input_names,
            'output_names': ['output'],
            'dynamic_axes': dynamic_axes,
        }
        if use_external_data:
            logger.info(f'Model size is {model_size_gb:.2f}GB, using external data format for ONNX export')
            # The keyword differs between PyTorch releases ('external_data' vs the older
            # 'use_external_data_format'). Pass only a keyword that the installed export()
            # actually accepts; passing an unknown one raises TypeError and aborts the export.
            accepted = inspect.signature(torch.onnx.export).parameters
            if 'external_data' in accepted:
                export_kwargs['external_data'] = True
            elif 'use_external_data_format' in accepted:
                export_kwargs['use_external_data_format'] = True
            # Otherwise no keyword is passed and the exporter's default handling applies.

        torch.onnx.export(
            wrapped_model,
            export_args,
            file_name,
            **export_kwargs,
        )

        # Clean up
        del dummy_input
        if cuda_available:
            torch.cuda.empty_cache()

        return file_name

    except Exception as e:
        # Best-effort API: report the failure and signal it with an empty path.
        logger.error(f'Failed to export HuggingFace model to ONNX: {str(e)}')
        logger.error(traceback.format_exc())
        return ''
0 commit comments