Skip to content

Commit 94e574b

Browse files
committed
refactor test
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
1 parent 4f3ee41 commit 94e574b

File tree

2 files changed

+51
-69
lines changed

2 files changed

+51
-69
lines changed

modelopt/onnx/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -700,9 +700,7 @@ def get_opset_version(model: onnx.ModelProto) -> int:
700700

701701

702702
def check_model_uses_external_data(model: onnx.ModelProto) -> bool:
703-
"""Checks if the model uses external data.
704-
True if any initializer tensor has data_location set to EXTERNAL.
705-
"""
703+
"""Checks if the model uses external data. True if any initializer tensor has data_location set to EXTERNAL."""
706704
return any(
707705
init.HasField("data_location") and init.data_location == onnx.TensorProto.EXTERNAL
708706
for init in model.graph.initializer

tests/unit/onnx/test_quantize_api.py

Lines changed: 50 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -38,85 +38,69 @@
3838
ORT_VERSION_FOR_OPSET_22 = version.parse("1.23.0")
3939

4040

41-
@pytest.mark.parametrize("quant_mode", ["int8", "fp8", "int4"])
42-
def test_opset_below_minimum_upgrades_to_minimum(tmp_path, quant_mode):
43-
"""Test that specifying opset below minimum upgrades to minimum."""
44-
model_torch = SimpleMLP()
45-
input_tensor = torch.randn(2, 16, 16)
46-
47-
onnx_path = os.path.join(tmp_path, "model.onnx")
48-
export_as_onnx(model_torch, input_tensor, onnx_filename=onnx_path)
49-
50-
min_opset = MIN_OPSET[quant_mode]
51-
52-
# Request opset below minimum
53-
moq.quantize(onnx_path, quantize_mode=quant_mode, opset=min_opset - 1)
54-
55-
# Verify output model was upgraded to minimum opset
56-
output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
57-
output_model = onnx.load(output_onnx_path)
58-
output_opset = get_opset_version(output_model)
59-
60-
assert output_opset == min_opset, (
61-
f"Expected opset {min_opset} for {quant_mode}, got {output_opset}"
62-
)
41+
# Test scenarios: (scenario_name, export_opset_offset, request_opset_offset, expected_opset_offset)
42+
# Offsets are relative to MIN_OPSET[quant_mode].
43+
OPSET_SCENARIOS = [
44+
# Requesting opset below minimum should upgrade to minimum
45+
("below_min_upgrades", -1, -1, 0),
46+
# Requesting opset below original model's opset (but not below minimum) should preserve original
47+
("below_original_preserves", 1, 0, 1),
48+
# Requesting opset above minimum should be respected
49+
("above_min_respected", 0, 1, 1),
50+
]
6351

6452

6553
@pytest.mark.parametrize("quant_mode", ["int8", "fp8", "int4"])
66-
def test_opset_below_original_uses_original(tmp_path, quant_mode):
67-
"""Test that specifying opset below original model's opset uses original."""
68-
model_torch = SimpleMLP()
69-
input_tensor = torch.randn(2, 16, 16)
70-
54+
@pytest.mark.parametrize(
55+
("scenario_name", "export_opset_offset", "request_opset_offset", "expected_opset_offset"),
56+
OPSET_SCENARIOS,
57+
ids=[s[0] for s in OPSET_SCENARIOS],
58+
)
59+
def test_quantize_opset_handling(
60+
tmp_path,
61+
quant_mode,
62+
scenario_name,
63+
export_opset_offset,
64+
request_opset_offset,
65+
expected_opset_offset,
66+
):
67+
"""Test opset handling in quantization API.
68+
69+
Scenarios:
70+
- below_min_upgrades: Requesting opset below minimum upgrades to minimum.
71+
- below_original_preserves: Requesting opset below original model's opset preserves original.
72+
- above_min_respected: Requesting opset at or above minimum is respected.
73+
"""
7174
min_opset = MIN_OPSET[quant_mode]
72-
higher_opset = min_opset + 1
7375

74-
# Skip if required opset exceeds onnxruntime support
75-
ort_version = version.parse(onnxruntime.__version__)
76-
if higher_opset >= 22 and ort_version < ORT_VERSION_FOR_OPSET_22:
77-
pytest.skip(
78-
f"Opset {higher_opset} requires onnxruntime >= {ORT_VERSION_FOR_OPSET_22}, have {ort_version}"
79-
)
80-
81-
onnx_path = os.path.join(tmp_path, "model.onnx")
82-
export_as_onnx(model_torch, input_tensor, onnx_filename=onnx_path, opset=higher_opset)
83-
84-
# Verify the exported model has the higher opset
85-
original_model = onnx.load(onnx_path)
86-
assert get_opset_version(original_model) == higher_opset
87-
88-
# Request opset below original (but above minimum)
89-
moq.quantize(onnx_path, quantize_mode=quant_mode, opset=min_opset)
90-
91-
# Verify output model preserves the higher original opset
92-
output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
93-
output_model = onnx.load(output_onnx_path)
94-
output_opset = get_opset_version(output_model)
76+
# Calculate actual opset values from offsets
77+
export_opset = min_opset + export_opset_offset
78+
request_opset = min_opset + request_opset_offset
79+
expected_opset = min_opset + expected_opset_offset
9580

96-
assert output_opset == higher_opset, (
97-
f"Expected original opset {higher_opset} to be preserved, got {output_opset}"
98-
)
99-
100-
101-
@pytest.mark.parametrize("quant_mode", ["int8", "fp8", "int4"])
102-
def test_opset_above_minimum(tmp_path, quant_mode):
103-
"""Test that specifying opset at or above minimum is respected."""
81+
# Skip if required opset exceeds onnxruntime support
82+
max_opset = max(export_opset, request_opset, expected_opset)
83+
if max_opset >= 22:
84+
ort_version = version.parse(onnxruntime.__version__)
85+
if ort_version < ORT_VERSION_FOR_OPSET_22:
86+
pytest.skip(
87+
f"Opset {max_opset} requires onnxruntime >= {ORT_VERSION_FOR_OPSET_22}, have {ort_version}"
88+
)
89+
90+
# Setup: create and export model
10491
model_torch = SimpleMLP()
10592
input_tensor = torch.randn(2, 16, 16)
106-
107-
min_opset = MIN_OPSET[quant_mode]
108-
target_opset = min_opset + 1
109-
11093
onnx_path = os.path.join(tmp_path, "model.onnx")
111-
export_as_onnx(model_torch, input_tensor, onnx_filename=onnx_path)
94+
export_as_onnx(model_torch, input_tensor, onnx_filename=onnx_path, opset=export_opset)
11295

113-
moq.quantize(onnx_path, quantize_mode=quant_mode, opset=target_opset)
96+
# Run quantization
97+
moq.quantize(onnx_path, quantize_mode=quant_mode, opset=request_opset)
11498

115-
# Verify output model has the requested opset
99+
# Verify output opset
116100
output_onnx_path = onnx_path.replace(".onnx", ".quant.onnx")
117101
output_model = onnx.load(output_onnx_path)
118102
output_opset = get_opset_version(output_model)
119103

120-
assert output_opset == target_opset, (
121-
f"Expected opset {target_opset} for {quant_mode}, got {output_opset}"
104+
assert output_opset == expected_opset, (
105+
f"[{scenario_name}] Expected opset {expected_opset} for {quant_mode}, got {output_opset}"
122106
)

0 commit comments

Comments
 (0)