|
38 | 38 | ORT_VERSION_FOR_OPSET_22 = version.parse("1.23.0") |
39 | 39 |
|
40 | 40 |
|
# Opset-handling scenarios for the quantize API, expressed as offsets relative
# to MIN_OPSET[quant_mode] so one table covers every quant mode:
#   (scenario_name, export_opset_offset, request_opset_offset, expected_opset_offset)
#
#   below_min_upgrades      -- asking for an opset under the minimum gets bumped up to it
#   below_original_preserves -- asking below the model's own opset (but >= minimum) keeps the original
#   above_min_respected     -- asking for an opset at/above the minimum is honored as-is
OPSET_SCENARIOS = [
    ("below_min_upgrades", -1, -1, 0),
    ("below_original_preserves", 1, 0, 1),
    ("above_min_respected", 0, 1, 1),
]
63 | 51 |
|
64 | 52 |
|
@pytest.mark.parametrize("quant_mode", ["int8", "fp8", "int4"])
@pytest.mark.parametrize(
    ("scenario_name", "export_opset_offset", "request_opset_offset", "expected_opset_offset"),
    OPSET_SCENARIOS,
    ids=[s[0] for s in OPSET_SCENARIOS],
)
def test_quantize_opset_handling(
    tmp_path,
    quant_mode,
    scenario_name,
    export_opset_offset,
    request_opset_offset,
    expected_opset_offset,
):
    """Test opset handling in the quantization API.

    All offsets are relative to MIN_OPSET[quant_mode] (see OPSET_SCENARIOS):

    - below_min_upgrades: Requesting opset below minimum upgrades to minimum.
    - below_original_preserves: Requesting opset below original model's opset preserves original.
    - above_min_respected: Requesting opset at or above minimum is respected.
    """
    min_opset = MIN_OPSET[quant_mode]

    # Resolve the concrete opset values for this quant mode from the offsets.
    export_opset = min_opset + export_opset_offset
    request_opset = min_opset + request_opset_offset
    expected_opset = min_opset + expected_opset_offset

    # Skip if any opset this scenario touches exceeds onnxruntime support.
    max_opset = max(export_opset, request_opset, expected_opset)
    if max_opset >= 22:
        ort_version = version.parse(onnxruntime.__version__)
        if ort_version < ORT_VERSION_FOR_OPSET_22:
            pytest.skip(
                f"Opset {max_opset} requires onnxruntime >= {ORT_VERSION_FOR_OPSET_22}, have {ort_version}"
            )

    # Setup: create and export model at the scenario's export opset.
    model_torch = SimpleMLP()
    input_tensor = torch.randn(2, 16, 16)
    onnx_path = os.path.join(tmp_path, "model.onnx")
    export_as_onnx(model_torch, input_tensor, onnx_filename=onnx_path, opset=export_opset)

    # Sanity-check that the exporter honored the requested opset: the
    # "preserves original" scenario is meaningless if the exported model does
    # not actually carry export_opset. Only checked for export_opset >= min_opset;
    # NOTE(review): the exporter is presumably free to clamp a below-minimum
    # request, so asserting on that path would test the exporter, not quantize —
    # confirm against export_as_onnx's contract.
    if export_opset >= min_opset:
        exported_opset = get_opset_version(onnx.load(onnx_path))
        assert exported_opset == export_opset, (
            f"[{scenario_name}] Export produced opset {exported_opset}, expected {export_opset}"
        )

    # Run quantization with the scenario's requested opset.
    moq.quantize(onnx_path, quantize_mode=quant_mode, opset=request_opset)

    # Verify the quantized model's opset matches the scenario's expectation.
    output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
    output_model = onnx.load(output_onnx_path)
    output_opset = get_opset_version(output_model)

    assert output_opset == expected_opset, (
        f"[{scenario_name}] Expected opset {expected_opset} for {quant_mode}, got {output_opset}"
    )
0 commit comments