2323import os
2424import sys
2525import tempfile
26- import unittest
2726
28- # Add parent directory to path
29- sys .path .insert (0 , os .path .dirname (os .path .dirname (os .path .abspath (__file__ ))))
27+ # Add parent and current directory to path
28+ _test_dir = os .path .dirname (os .path .abspath (__file__ ))
29+ sys .path .insert (0 , os .path .dirname (_test_dir ))
30+ sys .path .insert (0 , _test_dir )
3031
32+ import models as _test_models
3133import onnx
3234import onnx_graphsurgeon as gs
33- from onnx import helper
35+ import pytest
3436
3537from modelopt .onnx .quantization .autotune import Config , QDQAutotuner , RegionPattern
3638from modelopt .onnx .quantization .autotune .common import PatternCache , RegionType
3739
3840
39- def create_simple_conv_model ():
40- """
41- Create a simple ONNX model: Input -> Conv -> Relu -> Output.
42-
43- This is a minimal model for testing autotuner initialization.
44- """
45- # Input
46- input_tensor = helper .make_tensor_value_info ("input" , onnx .TensorProto .FLOAT , [1 , 3 , 224 , 224 ])
41+ @pytest .fixture
42+ def simple_conv_model ():
43+ """Simple ONNX model: Input -> Conv -> Relu -> Output. Created via models.py."""
44+ return _test_models ._create_simple_conv_onnx_model ()
4745
48- # Output
49- output_tensor = helper .make_tensor_value_info (
50- "output" , onnx .TensorProto .FLOAT , [1 , 64 , 224 , 224 ]
51- )
5246
53- # Conv node
54- conv_node = helper .make_node (
55- "Conv" , inputs = ["input" , "conv_weight" ], outputs = ["conv_out" ], name = "conv"
56- )
47+ def _create_test_config ():
48+ """
49+ Create a reasonable config for testing.
5750
58- # Relu node
59- relu_node = helper .make_node ("Relu" , inputs = ["conv_out" ], outputs = ["output" ], name = "relu" )
60-
61- # Create graph
62- graph = helper .make_graph (
63- [conv_node , relu_node ],
64- "simple_conv" ,
65- [input_tensor ],
66- [output_tensor ],
67- initializer = [
68- helper .make_tensor (
69- "conv_weight" , onnx .TensorProto .FLOAT , [64 , 3 , 3 , 3 ], [0.1 ] * (64 * 3 * 3 * 3 )
70- )
71- ],
51+ Uses sensible defaults suitable for unit tests:
52+ - verbose=False: Keep test output clean
53+ - maximum_sequence_region_size=50: Allow larger test regions
54+ - Other parameters: Match Config defaults for typical behavior
55+ """
56+ return Config (
57+ # Logging
58+ verbose = False ,
59+ # Performance Requirements
60+ # Quantization Parameters
61+ default_q_scale = 0.1 ,
62+ default_q_zero_point = 0 ,
63+ default_quant_type = "int8" ,
64+ # Region Builder Settings
65+ maximum_sequence_region_size = 50 ,
66+ minimum_topdown_search_size = 10 ,
67+ # Scheme Generation Settings
68+ top_percent_to_mutate = 0.1 ,
69+ minimum_schemes_to_mutate = 10 ,
70+ maximum_mutations = 3 ,
71+ maximum_generation_attempts = 100 ,
72+ # Pattern Cache Settings
73+ pattern_cache_minimum_distance = 4 ,
74+ pattern_cache_max_entries_per_pattern = 32 ,
7275 )
7376
74- # Create model
75- model = helper .make_model (graph , producer_name = "test" )
76- return model
77-
7877
79- class TestQDQAutotuner ( unittest . TestCase ) :
78+ class TestQDQAutotuner :
8079 """Test QDQAutotuner functionality."""
8180
82- @staticmethod
83- def _create_test_config ():
84- """
85- Create a reasonable config for testing.
86-
87- Uses sensible defaults suitable for unit tests:
88- - verbose=False: Keep test output clean
89- - maximum_sequence_region_size=50: Allow larger test regions
90- - Other parameters: Match Config defaults for typical behavior
91- """
92- return Config (
93- # Logging
94- verbose = False ,
95- # Performance Requirements
96- # Quantization Parameters
97- default_q_scale = 0.1 ,
98- default_q_zero_point = 0 ,
99- default_quant_type = "int8" ,
100- # Region Builder Settings
101- maximum_sequence_region_size = 50 ,
102- minimum_topdown_search_size = 10 ,
103- # Scheme Generation Settings
104- top_percent_to_mutate = 0.1 ,
105- minimum_schemes_to_mutate = 10 ,
106- maximum_mutations = 3 ,
107- maximum_generation_attempts = 100 ,
108- # Pattern Cache Settings
109- pattern_cache_minimum_distance = 4 ,
110- pattern_cache_max_entries_per_pattern = 32 ,
111- )
112-
113- def test_creation_with_onnx_model (self ):
81+ def test_creation_with_onnx_model (self , simple_conv_model ):
11482 """Test creating autotuner with ONNX ModelProto."""
115- model = create_simple_conv_model ()
116- autotuner = QDQAutotuner (model )
83+ autotuner = QDQAutotuner (simple_conv_model )
11784
11885 assert autotuner is not None
11986 assert autotuner .onnx_model is not None
12087 assert autotuner .graph is not None
12188
122- def test_creation_with_gs_graph (self ):
89+ def test_creation_with_gs_graph (self , simple_conv_model ):
12390 """Test creating autotuner with GraphSurgeon graph."""
124- model = create_simple_conv_model ()
125- gs_graph = gs .import_onnx (model )
126-
91+ gs_graph = gs .import_onnx (simple_conv_model )
12792 autotuner = QDQAutotuner (gs_graph )
12893
12994 assert autotuner is not None
13095 assert autotuner .graph is not None
13196
132- def test_initialize_with_default_config (self ):
97+ def test_initialize_with_default_config (self , simple_conv_model ):
13398 """Test initialization with default test config."""
134- model = create_simple_conv_model ()
135- autotuner = QDQAutotuner (model )
99+ autotuner = QDQAutotuner (simple_conv_model )
136100
137- config = self . _create_test_config ()
101+ config = _create_test_config ()
138102 autotuner .initialize (config )
139103
140104 # Should have provided config
@@ -144,10 +108,9 @@ def test_initialize_with_default_config(self):
144108 # Should have discovered regions
145109 assert len (autotuner .regions ) > 0
146110
147- def test_initialize_with_config (self ):
111+ def test_initialize_with_config (self , simple_conv_model ):
148112 """Test initialization with custom config (different from default)."""
149- model = create_simple_conv_model ()
150- autotuner = QDQAutotuner (model )
113+ autotuner = QDQAutotuner (simple_conv_model )
151114
152115 # Create custom config with different values
153116 config = Config (
@@ -180,23 +143,21 @@ def test_initialize_with_config(self):
180143 assert autotuner .config .pattern_cache_minimum_distance == 2
181144 assert autotuner .config .pattern_cache_max_entries_per_pattern == 16
182145
183- def test_initialize_with_pattern_cache (self ):
146+ def test_initialize_with_pattern_cache (self , simple_conv_model ):
184147 """Test initialization with pattern cache."""
185- model = create_simple_conv_model ()
186- autotuner = QDQAutotuner (model )
148+ autotuner = QDQAutotuner (simple_conv_model )
187149
188- config = self . _create_test_config ()
150+ config = _create_test_config ()
189151 pattern_cache = PatternCache ()
190152 autotuner .initialize (config , pattern_cache = pattern_cache )
191153
192154 assert autotuner .pattern_cache is not None
193155
194- def test_region_discovery (self ):
156+ def test_region_discovery (self , simple_conv_model ):
195157 """Test that regions are automatically discovered."""
196- model = create_simple_conv_model ()
197- autotuner = QDQAutotuner (model )
158+ autotuner = QDQAutotuner (simple_conv_model )
198159
199- config = self . _create_test_config ()
160+ config = _create_test_config ()
200161 autotuner .initialize (config )
201162
202163 # Should discover at least one region
@@ -207,11 +168,10 @@ def test_region_discovery(self):
207168 assert region .get_id () is not None
208169 assert region .get_type () in [RegionType .LEAF , RegionType .COMPOSITE , RegionType .ROOT ]
209170
210- def test_export_baseline_model (self ):
171+ def test_export_baseline_model (self , simple_conv_model ):
211172 """Test exporting baseline model without Q/DQ."""
212- model = create_simple_conv_model ()
213- autotuner = QDQAutotuner (model )
214- config = self ._create_test_config ()
173+ autotuner = QDQAutotuner (simple_conv_model )
174+ config = _create_test_config ()
215175 autotuner .initialize (config )
216176
217177 with tempfile .NamedTemporaryFile (suffix = ".onnx" , delete = False ) as f :
@@ -229,11 +189,10 @@ def test_export_baseline_model(self):
229189 if os .path .exists (output_path ):
230190 os .unlink (output_path )
231191
232- def test_set_profile_region (self ):
192+ def test_set_profile_region (self , simple_conv_model ):
233193 """Test setting a region for profiling."""
234- model = create_simple_conv_model ()
235- autotuner = QDQAutotuner (model )
236- config = self ._create_test_config ()
194+ autotuner = QDQAutotuner (simple_conv_model )
195+ config = _create_test_config ()
237196 autotuner .initialize (config )
238197
239198 if len (autotuner .regions ) > 0 :
@@ -243,13 +202,12 @@ def test_set_profile_region(self):
243202 assert autotuner .current_profile_region == region
244203 assert autotuner .current_profile_pattern_schemes is not None
245204 else :
246- self . skipTest ("No regions discovered" )
205+ pytest . skip ("No regions discovered" )
247206
248- def test_generate_scheme (self ):
207+ def test_generate_scheme (self , simple_conv_model ):
249208 """Test generating an insertion scheme."""
250- model = create_simple_conv_model ()
251- autotuner = QDQAutotuner (model )
252- config = self ._create_test_config ()
209+ autotuner = QDQAutotuner (simple_conv_model )
210+ config = _create_test_config ()
253211 autotuner .initialize (config )
254212
255213 if len (autotuner .regions ) > 0 :
@@ -260,24 +218,22 @@ def test_generate_scheme(self):
260218 # Should return a valid index (>= 0) or -1 if no more unique schemes
261219 assert isinstance (scheme_idx , int )
262220 else :
263- self . skipTest ("No regions discovered" )
221+ pytest . skip ("No regions discovered" )
264222
265- def test_submit_latency (self ):
223+ def test_submit_latency (self , simple_conv_model ):
266224 """Test submitting performance measurement."""
267- model = create_simple_conv_model ()
268- autotuner = QDQAutotuner (model )
269- config = self ._create_test_config ()
225+ autotuner = QDQAutotuner (simple_conv_model )
226+ config = _create_test_config ()
270227 autotuner .initialize (config )
271228 # Submit baseline latency
272229 autotuner .submit (10.5 )
273230 # Baseline should be recorded
274231 assert autotuner .baseline_latency_ms == 10.5
275232
276- def test_save_and_load_state (self ):
233+ def test_save_and_load_state (self , simple_conv_model ):
277234 """Test saving and loading autotuner state."""
278- model = create_simple_conv_model ()
279- autotuner = QDQAutotuner (model )
280- config = self ._create_test_config ()
235+ autotuner = QDQAutotuner (simple_conv_model )
236+ config = _create_test_config ()
281237 autotuner .initialize (config )
282238
283239 # Submit some results
@@ -292,8 +248,8 @@ def test_save_and_load_state(self):
292248 assert os .path .exists (state_path )
293249
294250 # Create new autotuner and load state
295- autotuner2 = QDQAutotuner (model )
296- config2 = self . _create_test_config ()
251+ autotuner2 = QDQAutotuner (simple_conv_model )
252+ config2 = _create_test_config ()
297253 autotuner2 .initialize (config2 )
298254 autotuner2 .load_state (state_path )
299255
@@ -303,11 +259,10 @@ def test_save_and_load_state(self):
303259 if os .path .exists (state_path ):
304260 os .unlink (state_path )
305261
306- def test_regions_prioritization (self ):
262+ def test_regions_prioritization (self , simple_conv_model ):
307263 """Test that LEAF regions are prioritized."""
308- model = create_simple_conv_model ()
309- autotuner = QDQAutotuner (model )
310- config = self ._create_test_config ()
264+ autotuner = QDQAutotuner (simple_conv_model )
265+ config = _create_test_config ()
311266 autotuner .initialize (config )
312267
313268 # Check that LEAF regions come before non-LEAF
@@ -322,11 +277,10 @@ def test_regions_prioritization(self):
322277 # All LEAF should come before non-LEAF
323278 assert max (leaf_indices ) < min (non_leaf_indices )
324279
325- def test_profiled_patterns_tracking (self ):
280+ def test_profiled_patterns_tracking (self , simple_conv_model ):
326281 """Test that profiled patterns are tracked."""
327- model = create_simple_conv_model ()
328- autotuner = QDQAutotuner (model )
329- config = self ._create_test_config ()
282+ autotuner = QDQAutotuner (simple_conv_model )
283+ config = _create_test_config ()
330284 autotuner .initialize (config )
331285 autotuner .submit (10.0 )
332286
@@ -342,4 +296,4 @@ def test_profiled_patterns_tracking(self):
342296 profiled_patterns = [p .pattern .signature for p in autotuner .profiled_patterns ]
343297 assert pattern_sig in profiled_patterns
344298 else :
345- self . skipTest ("No regions discovered" )
299+ pytest . skip ("No regions discovered" )
0 commit comments