|
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | 16 |
|
| 17 | +import os |
| 18 | +import subprocess |
| 19 | + |
17 | 20 | import pytest |
18 | 21 | from _test_utils.examples.run_command import extend_cmd_parts, run_example_command |
19 | 22 |
|
| 23 | +# TODO: Add int4_awq once the INT4 exporter supports non-MatMul/Gemm consumer patterns |
| 24 | +# (e.g., DQ -> Reshape -> Slice in small ViT / SwinTransformer ONNX graphs). |
| 25 | +_QUANT_MODES = ["fp8", "int8", "mxfp8", "nvfp4", "auto"] |
| 26 | + |
| 27 | +_MODELS = { |
| 28 | + "vit_tiny": ("vit_tiny_patch16_224", '{"depth": 1}'), |
| 29 | + "swin_tiny": ("swin_tiny_patch4_window7_224", '{"depths": [1, 1, 1, 1]}'), |
| 30 | + "swinv2_tiny": ("swinv2_tiny_window8_256", '{"depths": [1, 1, 1, 1]}'), |
| 31 | +} |
| 32 | + |
| 33 | +# auto_quantize uses backward hooks for gradient-based scoring which conflicts with |
| 34 | +# SwinV2's inplace ReLU (PyTorch view+inplace autograd limitation). |
| 35 | +_SKIP = {("swinv2_tiny", "auto")} |
| 36 | + |
| 37 | +# Builder optimization level: 4 for low-bit modes, 3 otherwise |
| 38 | +_LOW_BIT_MODES = {"fp8", "int8", "nvfp4"} |
| 39 | + |
| 40 | + |
| 41 | +def _verify_trt_engine_build(onnx_save_path, quantize_mode): |
| 42 | + """Verify the exported ONNX model can be compiled into a TensorRT engine.""" |
| 43 | + example_dir = os.path.join( |
| 44 | + os.path.dirname(__file__), "..", "..", "..", "examples", "torch_onnx" |
| 45 | + ) |
| 46 | + onnx_path = os.path.join(example_dir, onnx_save_path) |
| 47 | + assert os.path.exists(onnx_path), f"ONNX file not found: {onnx_path}" |
20 | 48 |
|
21 | | -# TODO: Add accuracy evaluation after we upgrade TRT version to 10.12 |
22 | | -@pytest.mark.parametrize( |
23 | | - ("quantize_mode", "onnx_save_path", "calib_size", "num_score_steps"), |
24 | | - [ |
25 | | - ("fp8", "vit_base_patch16_224.fp8.onnx", "1", "1"), |
26 | | - ("int8", "vit_base_patch16_224.int8.onnx", "1", "1"), |
27 | | - ("nvfp4", "vit_base_patch16_224.nvfp4.onnx", "1", "1"), |
28 | | - ("mxfp8", "vit_base_patch16_224.mxfp8.onnx", "1", "1"), |
29 | | - ("int4_awq", "vit_base_patch16_224.int4_awq.onnx", "1", "1"), |
30 | | - ("auto", "vit_base_patch16_224.auto.onnx", "1", "1"), |
31 | | - ], |
32 | | -) |
33 | | -def test_torch_onnx(quantize_mode, onnx_save_path, calib_size, num_score_steps): |
| 49 | + opt_level = "4" if quantize_mode in _LOW_BIT_MODES else "3" |
| 50 | + cmd = [ |
| 51 | + "trtexec", |
| 52 | + f"--onnx={onnx_path}", |
| 53 | + "--stronglyTyped", |
| 54 | + f"--builderOptimizationLevel={opt_level}", |
| 55 | + ] |
| 56 | + |
| 57 | + result = subprocess.run(cmd, capture_output=True, text=True, timeout=600) # nosec |
| 58 | + assert result.returncode == 0, ( |
| 59 | + f"TensorRT engine build failed for {onnx_save_path} " |
| 60 | + f"(mode={quantize_mode}):\n{result.stdout}\n{result.stderr}" |
| 61 | + ) |
| 62 | + |
| 63 | + |
@pytest.mark.parametrize("quantize_mode", _QUANT_MODES)
@pytest.mark.parametrize("model_key", list(_MODELS))
def test_torch_onnx(model_key, quantize_mode):
    """Quantize a small timm model, export it to ONNX, and build a TRT engine."""
    if (model_key, quantize_mode) in _SKIP:
        pytest.skip(f"{model_key} + {quantize_mode} is not supported")

    timm_model_name, model_kwargs = _MODELS[model_key]
    onnx_save_path = f"{model_key}.{quantize_mode}.onnx"

    # Step 1: Quantize and export to ONNX via the example script.
    quantize_cmd = extend_cmd_parts(
        ["python", "torch_quant_to_onnx.py"],
        timm_model_name=timm_model_name,
        model_kwargs=model_kwargs,
        quantize_mode=quantize_mode,
        onnx_save_path=onnx_save_path,
        calibration_data_size="1",
        num_score_steps="1",
    )
    # Randomly initialized weights are enough for an export/build smoke test.
    quantize_cmd.append("--no_pretrained")
    run_example_command(quantize_cmd, "torch_onnx")

    # Step 2: Verify the exported ONNX model builds a TensorRT engine.
    _verify_trt_engine_build(onnx_save_path, quantize_mode)
0 commit comments