|
9 | 9 | import torch.hub |
10 | 10 | import torch.onnx |
11 | 11 | import torchvision.models |
12 | | -from transformers import BertConfig, GPT2Config, LlamaConfig |
13 | 12 |
|
14 | | -from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel |
15 | | -from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel |
16 | | -from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel |
17 | | -from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel |
18 | | -from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel |
| 13 | +import traceback |
19 | 14 |
|
20 | | -if MixtralBenchmarkModel is not None: |
21 | | - from transformers import MixtralConfig |
| 15 | +from superbench.common.utils import logger |
22 | 16 |
|
23 | 17 |
|
24 | 18 | class torch2onnxExporter(): |
25 | 19 | """PyTorch model to ONNX exporter.""" |
26 | 20 | def __init__(self): |
27 | 21 | """Constructor.""" |
| 22 | + from transformers import BertConfig, GPT2Config, LlamaConfig |
| 23 | + from superbench.benchmarks.model_benchmarks.pytorch_bert import BertBenchmarkModel |
| 24 | + from superbench.benchmarks.model_benchmarks.pytorch_gpt2 import GPT2BenchmarkModel |
| 25 | + from superbench.benchmarks.model_benchmarks.pytorch_lstm import LSTMBenchmarkModel |
| 26 | + from superbench.benchmarks.model_benchmarks.pytorch_llama import LlamaBenchmarkModel |
| 27 | + from superbench.benchmarks.model_benchmarks.pytorch_mixtral import MixtralBenchmarkModel |
| 28 | + |
28 | 29 | self.num_classes = 100 |
29 | 30 | self.lstm_input_size = 256 |
30 | 31 | self.benchmark_models = { |
@@ -129,6 +130,7 @@ def __init__(self): |
129 | 130 |
|
130 | 131 | # Only include Mixtral models if MixtralBenchmarkModel is available |
131 | 132 | if MixtralBenchmarkModel is not None: |
| 133 | + from transformers import MixtralConfig |
132 | 134 | self.benchmark_models.update( |
133 | 135 | { |
134 | 136 | 'mixtral-8x7b': |
@@ -270,3 +272,151 @@ def export_benchmark_model(self, model_name, batch_size=1, seq_length=512): |
270 | 272 | del dummy_input |
271 | 273 | torch.cuda.empty_cache() |
272 | 274 | return file_name |
| 275 | + |
def export_huggingface_model(self, model, model_name, batch_size=1, seq_length=512, output_dir=None):
    """Export a HuggingFace model to ONNX format.

    The model is switched to eval mode, its ``use_cache`` config flag (when present) is
    turned off, and it is moved to CUDA when CUDA is available. These changes are applied
    to the ``model`` object passed in and are not reverted afterwards.

    Models whose ``main_input_name`` is ``pixel_values`` are traced with a single
    ``(batch_size, 3, 224, 224)`` tensor; every other model is traced with ``input_ids``
    and ``attention_mask`` tensors of shape ``(batch_size, seq_length)``.

    Args:
        model: HuggingFace model instance to export.
        model_name (str): Name for the exported ONNX model file.
        batch_size (int): Batch size of input. Defaults to 1.
        seq_length (int): Sequence length of input. Defaults to 512.
        output_dir (str): Output directory path. If None, uses default path.
            The directory is created if it does not exist yet.

    Returns:
        str: Exported ONNX model file path, or empty string if export fails.
    """
    def _primary_output(outputs):
        """Pick the single tensor the exported graph should return."""
        if hasattr(outputs, 'logits'):
            return outputs.logits
        if hasattr(outputs, 'last_hidden_state'):
            return outputs.last_hidden_state
        return outputs[0] if isinstance(outputs, (tuple, list)) else outputs

    class VisionModelWrapper(torch.nn.Module):
        """Expose a vision model as ``pixel_values -> tensor`` for tracing."""
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, pixel_values):
            return _primary_output(self.model(pixel_values=pixel_values))

    class NLPModelWrapper(torch.nn.Module):
        """Expose an NLP model as ``(input_ids, attention_mask) -> tensor`` for tracing."""
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids, attention_mask):
            return _primary_output(self.model(input_ids=input_ids, attention_mask=attention_mask))

    # Pre-bind everything the ``finally`` clause releases so cleanup is safe
    # even when the export fails before these objects are created.
    dummy_input = attention_mask = wrapped_model = None
    try:
        # Use custom output directory if provided, and make sure it exists:
        # torch.onnx.export does not create missing parent directories.
        output_path = Path(output_dir) if output_dir else self._onnx_model_path
        Path(output_path).mkdir(parents=True, exist_ok=True)
        file_name = str(output_path / (model_name + '.onnx'))

        # Put model in eval mode
        model.eval()

        # Disable cache to avoid DynamicCache issues with ONNX export
        if hasattr(model.config, 'use_cache'):
            model.config.use_cache = False

        # Query CUDA availability once so model and inputs land on the same device
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            model = model.cuda()
        device = 'cuda' if use_cuda else 'cpu'

        # Get model's dtype for inputs
        model_dtype = next(model.parameters()).dtype

        # Detect model type and create appropriate inputs.
        # Vision models use pixel_values, NLP models use input_ids;
        # HuggingFace's main_input_name attribute tells them apart.
        main_input = getattr(model, 'main_input_name', 'input_ids')
        is_vision_model = main_input == 'pixel_values'

        if is_vision_model:
            # Vision models: pixel_values is (batch_size, channels, height, width).
            # Standard ImageNet size is 224x224 with 3 channels; dtype matches the model.
            dummy_input = torch.randn(batch_size, 3, 224, 224, dtype=model_dtype, device=device)
            input_names = ['pixel_values']
            dynamic_axes = {'pixel_values': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
            wrapped_model = VisionModelWrapper(model)
            export_args = (dummy_input, )
        else:
            # NLP models: use input_ids and attention_mask
            dummy_input = torch.ones((batch_size, seq_length), dtype=torch.int64, device=device)
            attention_mask = torch.ones((batch_size, seq_length), dtype=torch.int64, device=device)
            input_names = ['input_ids', 'attention_mask']
            dynamic_axes = {
                'input_ids': {
                    0: 'batch_size',
                    1: 'seq_length'
                },
                'attention_mask': {
                    0: 'batch_size',
                    1: 'seq_length'
                },
                'output': {
                    0: 'batch_size'
                },
            }
            wrapped_model = NLPModelWrapper(model)
            export_args = (dummy_input, attention_mask)

        # For large models (>2GB), use external data format
        model_size_gb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**3)
        use_external_data = model_size_gb > 2.0

        if use_external_data:
            logger.info(f'Model size is {model_size_gb:.2f}GB, using external data format for ONNX export')

        torch.onnx.export(
            wrapped_model,
            export_args,
            file_name,
            opset_version=14,
            do_constant_folding=True,
            input_names=input_names,
            output_names=['output'],
            dynamic_axes=dynamic_axes,
        )

        # If using external data, convert to external data format
        if use_external_data:
            import onnx
            from onnx.external_data_helper import convert_model_to_external_data

            onnx_model = onnx.load(file_name)
            external_data_path = model_name + '_data.bin'

            # onnx appends to an existing external data file instead of truncating it,
            # so drop any leftover from a previous export. The tensors are already in
            # memory after onnx.load, so nothing needed is lost.
            stale_data_file = Path(output_path) / external_data_path
            if stale_data_file.exists():
                stale_data_file.unlink()

            convert_model_to_external_data(
                onnx_model,
                all_tensors_to_one_file=True,
                location=external_data_path,
                size_threshold=1024,
                convert_attribute=False
            )
            onnx.save(onnx_model, file_name)
            logger.info(f'Converted ONNX model to external data format: {external_data_path}')

        return file_name

    except Exception as e:
        logger.error(f'Failed to export HuggingFace model to ONNX: {str(e)}')
        logger.error(traceback.format_exc())
        return ''

    finally:
        # Clean up on success AND failure: drop the local tensor references first so
        # empty_cache() can actually hand their memory back to the device.
        dummy_input = attention_mask = wrapped_model = None
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            # Keep the "return '' instead of raising" contract even if the
            # CUDA context is in a bad state after a failed export.
            logger.error(traceback.format_exc())
0 commit comments