Skip to content

Commit 0f41470

Browse files
committed
resolve comments
Signed-off-by: Will Guo <willg@nvidia.com>
1 parent f93f5d5 commit 0f41470

2 files changed

Lines changed: 126 additions & 125 deletions

File tree

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""
17+
Shared test ONNX models for autotuner unit tests.
18+
19+
Model creation functions live here; tests import and call them directly.
20+
"""
21+
22+
import onnx
23+
from onnx import helper
24+
25+
26+
def _create_simple_conv_onnx_model():
27+
"""Build ONNX model: Input -> Conv -> Relu -> Output (minimal for autotuner tests)."""
28+
input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224])
29+
output_tensor = helper.make_tensor_value_info(
30+
"output", onnx.TensorProto.FLOAT, [1, 64, 224, 224]
31+
)
32+
conv_node = helper.make_node(
33+
"Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv"
34+
)
35+
relu_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["output"], name="relu")
36+
graph = helper.make_graph(
37+
[conv_node, relu_node],
38+
"simple_conv",
39+
[input_tensor],
40+
[output_tensor],
41+
initializer=[
42+
helper.make_tensor(
43+
"conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3)
44+
)
45+
],
46+
)
47+
return helper.make_model(graph, producer_name="test")

tests/unit/onnx/quantization/autotune/autotune/test_autotuner.py

Lines changed: 79 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -23,118 +23,82 @@
2323
import os
2424
import sys
2525
import tempfile
26-
import unittest
2726

28-
# Add parent directory to path
29-
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
27+
# Add parent and current directory to path
28+
_test_dir = os.path.dirname(os.path.abspath(__file__))
29+
sys.path.insert(0, os.path.dirname(_test_dir))
30+
sys.path.insert(0, _test_dir)
3031

32+
import models as _test_models
3133
import onnx
3234
import onnx_graphsurgeon as gs
33-
from onnx import helper
35+
import pytest
3436

3537
from modelopt.onnx.quantization.autotune import Config, QDQAutotuner, RegionPattern
3638
from modelopt.onnx.quantization.autotune.common import PatternCache, RegionType
3739

3840

39-
def create_simple_conv_model():
40-
"""
41-
Create a simple ONNX model: Input -> Conv -> Relu -> Output.
42-
43-
This is a minimal model for testing autotuner initialization.
44-
"""
45-
# Input
46-
input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224])
41+
@pytest.fixture
42+
def simple_conv_model():
43+
"""Simple ONNX model: Input -> Conv -> Relu -> Output. Created via models.py."""
44+
return _test_models._create_simple_conv_onnx_model()
4745

48-
# Output
49-
output_tensor = helper.make_tensor_value_info(
50-
"output", onnx.TensorProto.FLOAT, [1, 64, 224, 224]
51-
)
5246

53-
# Conv node
54-
conv_node = helper.make_node(
55-
"Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv"
56-
)
47+
def _create_test_config():
48+
"""
49+
Create a reasonable config for testing.
5750
58-
# Relu node
59-
relu_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["output"], name="relu")
60-
61-
# Create graph
62-
graph = helper.make_graph(
63-
[conv_node, relu_node],
64-
"simple_conv",
65-
[input_tensor],
66-
[output_tensor],
67-
initializer=[
68-
helper.make_tensor(
69-
"conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3)
70-
)
71-
],
51+
Uses sensible defaults suitable for unit tests:
52+
- verbose=False: Keep test output clean
53+
- maximum_sequence_region_size=50: Allow larger test regions
54+
- Other parameters: Match Config defaults for typical behavior
55+
"""
56+
return Config(
57+
# Logging
58+
verbose=False,
59+
# Performance Requirements
60+
# Quantization Parameters
61+
default_q_scale=0.1,
62+
default_q_zero_point=0,
63+
default_quant_type="int8",
64+
# Region Builder Settings
65+
maximum_sequence_region_size=50,
66+
minimum_topdown_search_size=10,
67+
# Scheme Generation Settings
68+
top_percent_to_mutate=0.1,
69+
minimum_schemes_to_mutate=10,
70+
maximum_mutations=3,
71+
maximum_generation_attempts=100,
72+
# Pattern Cache Settings
73+
pattern_cache_minimum_distance=4,
74+
pattern_cache_max_entries_per_pattern=32,
7275
)
7376

74-
# Create model
75-
model = helper.make_model(graph, producer_name="test")
76-
return model
77-
7877

79-
class TestQDQAutotuner(unittest.TestCase):
78+
class TestQDQAutotuner:
8079
"""Test QDQAutotuner functionality."""
8180

82-
@staticmethod
83-
def _create_test_config():
84-
"""
85-
Create a reasonable config for testing.
86-
87-
Uses sensible defaults suitable for unit tests:
88-
- verbose=False: Keep test output clean
89-
- maximum_sequence_region_size=50: Allow larger test regions
90-
- Other parameters: Match Config defaults for typical behavior
91-
"""
92-
return Config(
93-
# Logging
94-
verbose=False,
95-
# Performance Requirements
96-
# Quantization Parameters
97-
default_q_scale=0.1,
98-
default_q_zero_point=0,
99-
default_quant_type="int8",
100-
# Region Builder Settings
101-
maximum_sequence_region_size=50,
102-
minimum_topdown_search_size=10,
103-
# Scheme Generation Settings
104-
top_percent_to_mutate=0.1,
105-
minimum_schemes_to_mutate=10,
106-
maximum_mutations=3,
107-
maximum_generation_attempts=100,
108-
# Pattern Cache Settings
109-
pattern_cache_minimum_distance=4,
110-
pattern_cache_max_entries_per_pattern=32,
111-
)
112-
113-
def test_creation_with_onnx_model(self):
81+
def test_creation_with_onnx_model(self, simple_conv_model):
11482
"""Test creating autotuner with ONNX ModelProto."""
115-
model = create_simple_conv_model()
116-
autotuner = QDQAutotuner(model)
83+
autotuner = QDQAutotuner(simple_conv_model)
11784

11885
assert autotuner is not None
11986
assert autotuner.onnx_model is not None
12087
assert autotuner.graph is not None
12188

122-
def test_creation_with_gs_graph(self):
89+
def test_creation_with_gs_graph(self, simple_conv_model):
12390
"""Test creating autotuner with GraphSurgeon graph."""
124-
model = create_simple_conv_model()
125-
gs_graph = gs.import_onnx(model)
126-
91+
gs_graph = gs.import_onnx(simple_conv_model)
12792
autotuner = QDQAutotuner(gs_graph)
12893

12994
assert autotuner is not None
13095
assert autotuner.graph is not None
13196

132-
def test_initialize_with_default_config(self):
97+
def test_initialize_with_default_config(self, simple_conv_model):
13398
"""Test initialization with default test config."""
134-
model = create_simple_conv_model()
135-
autotuner = QDQAutotuner(model)
99+
autotuner = QDQAutotuner(simple_conv_model)
136100

137-
config = self._create_test_config()
101+
config = _create_test_config()
138102
autotuner.initialize(config)
139103

140104
# Should have provided config
@@ -144,10 +108,9 @@ def test_initialize_with_default_config(self):
144108
# Should have discovered regions
145109
assert len(autotuner.regions) > 0
146110

147-
def test_initialize_with_config(self):
111+
def test_initialize_with_config(self, simple_conv_model):
148112
"""Test initialization with custom config (different from default)."""
149-
model = create_simple_conv_model()
150-
autotuner = QDQAutotuner(model)
113+
autotuner = QDQAutotuner(simple_conv_model)
151114

152115
# Create custom config with different values
153116
config = Config(
@@ -180,23 +143,21 @@ def test_initialize_with_config(self):
180143
assert autotuner.config.pattern_cache_minimum_distance == 2
181144
assert autotuner.config.pattern_cache_max_entries_per_pattern == 16
182145

183-
def test_initialize_with_pattern_cache(self):
146+
def test_initialize_with_pattern_cache(self, simple_conv_model):
184147
"""Test initialization with pattern cache."""
185-
model = create_simple_conv_model()
186-
autotuner = QDQAutotuner(model)
148+
autotuner = QDQAutotuner(simple_conv_model)
187149

188-
config = self._create_test_config()
150+
config = _create_test_config()
189151
pattern_cache = PatternCache()
190152
autotuner.initialize(config, pattern_cache=pattern_cache)
191153

192154
assert autotuner.pattern_cache is not None
193155

194-
def test_region_discovery(self):
156+
def test_region_discovery(self, simple_conv_model):
195157
"""Test that regions are automatically discovered."""
196-
model = create_simple_conv_model()
197-
autotuner = QDQAutotuner(model)
158+
autotuner = QDQAutotuner(simple_conv_model)
198159

199-
config = self._create_test_config()
160+
config = _create_test_config()
200161
autotuner.initialize(config)
201162

202163
# Should discover at least one region
@@ -207,11 +168,10 @@ def test_region_discovery(self):
207168
assert region.get_id() is not None
208169
assert region.get_type() in [RegionType.LEAF, RegionType.COMPOSITE, RegionType.ROOT]
209170

210-
def test_export_baseline_model(self):
171+
def test_export_baseline_model(self, simple_conv_model):
211172
"""Test exporting baseline model without Q/DQ."""
212-
model = create_simple_conv_model()
213-
autotuner = QDQAutotuner(model)
214-
config = self._create_test_config()
173+
autotuner = QDQAutotuner(simple_conv_model)
174+
config = _create_test_config()
215175
autotuner.initialize(config)
216176

217177
with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
@@ -229,11 +189,10 @@ def test_export_baseline_model(self):
229189
if os.path.exists(output_path):
230190
os.unlink(output_path)
231191

232-
def test_set_profile_region(self):
192+
def test_set_profile_region(self, simple_conv_model):
233193
"""Test setting a region for profiling."""
234-
model = create_simple_conv_model()
235-
autotuner = QDQAutotuner(model)
236-
config = self._create_test_config()
194+
autotuner = QDQAutotuner(simple_conv_model)
195+
config = _create_test_config()
237196
autotuner.initialize(config)
238197

239198
if len(autotuner.regions) > 0:
@@ -243,13 +202,12 @@ def test_set_profile_region(self):
243202
assert autotuner.current_profile_region == region
244203
assert autotuner.current_profile_pattern_schemes is not None
245204
else:
246-
self.skipTest("No regions discovered")
205+
pytest.skip("No regions discovered")
247206

248-
def test_generate_scheme(self):
207+
def test_generate_scheme(self, simple_conv_model):
249208
"""Test generating an insertion scheme."""
250-
model = create_simple_conv_model()
251-
autotuner = QDQAutotuner(model)
252-
config = self._create_test_config()
209+
autotuner = QDQAutotuner(simple_conv_model)
210+
config = _create_test_config()
253211
autotuner.initialize(config)
254212

255213
if len(autotuner.regions) > 0:
@@ -260,24 +218,22 @@ def test_generate_scheme(self):
260218
# Should return a valid index (>= 0) or -1 if no more unique schemes
261219
assert isinstance(scheme_idx, int)
262220
else:
263-
self.skipTest("No regions discovered")
221+
pytest.skip("No regions discovered")
264222

265-
def test_submit_latency(self):
223+
def test_submit_latency(self, simple_conv_model):
266224
"""Test submitting performance measurement."""
267-
model = create_simple_conv_model()
268-
autotuner = QDQAutotuner(model)
269-
config = self._create_test_config()
225+
autotuner = QDQAutotuner(simple_conv_model)
226+
config = _create_test_config()
270227
autotuner.initialize(config)
271228
# Submit baseline latency
272229
autotuner.submit(10.5)
273230
# Baseline should be recorded
274231
assert autotuner.baseline_latency_ms == 10.5
275232

276-
def test_save_and_load_state(self):
233+
def test_save_and_load_state(self, simple_conv_model):
277234
"""Test saving and loading autotuner state."""
278-
model = create_simple_conv_model()
279-
autotuner = QDQAutotuner(model)
280-
config = self._create_test_config()
235+
autotuner = QDQAutotuner(simple_conv_model)
236+
config = _create_test_config()
281237
autotuner.initialize(config)
282238

283239
# Submit some results
@@ -292,8 +248,8 @@ def test_save_and_load_state(self):
292248
assert os.path.exists(state_path)
293249

294250
# Create new autotuner and load state
295-
autotuner2 = QDQAutotuner(model)
296-
config2 = self._create_test_config()
251+
autotuner2 = QDQAutotuner(simple_conv_model)
252+
config2 = _create_test_config()
297253
autotuner2.initialize(config2)
298254
autotuner2.load_state(state_path)
299255

@@ -303,11 +259,10 @@ def test_save_and_load_state(self):
303259
if os.path.exists(state_path):
304260
os.unlink(state_path)
305261

306-
def test_regions_prioritization(self):
262+
def test_regions_prioritization(self, simple_conv_model):
307263
"""Test that LEAF regions are prioritized."""
308-
model = create_simple_conv_model()
309-
autotuner = QDQAutotuner(model)
310-
config = self._create_test_config()
264+
autotuner = QDQAutotuner(simple_conv_model)
265+
config = _create_test_config()
311266
autotuner.initialize(config)
312267

313268
# Check that LEAF regions come before non-LEAF
@@ -322,11 +277,10 @@ def test_regions_prioritization(self):
322277
# All LEAF should come before non-LEAF
323278
assert max(leaf_indices) < min(non_leaf_indices)
324279

325-
def test_profiled_patterns_tracking(self):
280+
def test_profiled_patterns_tracking(self, simple_conv_model):
326281
"""Test that profiled patterns are tracked."""
327-
model = create_simple_conv_model()
328-
autotuner = QDQAutotuner(model)
329-
config = self._create_test_config()
282+
autotuner = QDQAutotuner(simple_conv_model)
283+
config = _create_test_config()
330284
autotuner.initialize(config)
331285
autotuner.submit(10.0)
332286

@@ -342,4 +296,4 @@ def test_profiled_patterns_tracking(self):
342296
profiled_patterns = [p.pattern.signature for p in autotuner.profiled_patterns]
343297
assert pattern_sig in profiled_patterns
344298
else:
345-
self.skipTest("No regions discovered")
299+
pytest.skip("No regions discovered")

0 commit comments

Comments
 (0)