Skip to content

Commit db2f2cc

Browse files
committed
add unittest for region_inspect
Signed-off-by: Will Guo <willg@nvidia.com>
1 parent 610d9a9 commit db2f2cc

File tree

1 file changed

+367
-0
lines changed

1 file changed

+367
-0
lines changed
Lines changed: 367 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,367 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Unit tests for region_inspect module."""
17+
18+
import os
19+
from unittest.mock import Mock, patch
20+
21+
import numpy as np
22+
import onnx
23+
import pytest
24+
from onnx import TensorProto, helper, numpy_helper
25+
26+
27+
def create_simple_onnx_model():
28+
"""Create a simple ONNX model for testing.
29+
30+
Creates a model with: Input -> Conv -> Relu -> MatMul -> Output
31+
"""
32+
# Create input
33+
input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 224, 224])
34+
output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1000])
35+
36+
# Create weights for Conv
37+
conv_weight = np.random.randn(64, 3, 7, 7).astype(np.float32)
38+
conv_weight_tensor = numpy_helper.from_array(conv_weight, "conv_weight")
39+
40+
# Create weights for MatMul
41+
matmul_weight = np.random.randn(64, 1000).astype(np.float32)
42+
matmul_weight_tensor = numpy_helper.from_array(matmul_weight, "matmul_weight")
43+
44+
# Create nodes
45+
conv_node = helper.make_node(
46+
"Conv",
47+
inputs=["input", "conv_weight"],
48+
outputs=["conv_output"],
49+
kernel_shape=[7, 7],
50+
strides=[2, 2],
51+
pads=[3, 3, 3, 3],
52+
)
53+
54+
relu_node = helper.make_node(
55+
"Relu",
56+
inputs=["conv_output"],
57+
outputs=["relu_output"],
58+
)
59+
60+
flatten_node = helper.make_node(
61+
"Flatten",
62+
inputs=["relu_output"],
63+
outputs=["flatten_output"],
64+
axis=1,
65+
)
66+
67+
matmul_node = helper.make_node(
68+
"MatMul",
69+
inputs=["flatten_output", "matmul_weight"],
70+
outputs=["output"],
71+
)
72+
73+
# Create graph
74+
graph = helper.make_graph(
75+
[conv_node, relu_node, flatten_node, matmul_node],
76+
"test_model",
77+
[input_tensor],
78+
[output_tensor],
79+
[conv_weight_tensor, matmul_weight_tensor],
80+
)
81+
82+
# Create model
83+
model = helper.make_model(graph, producer_name="test")
84+
model.opset_import[0].version = 13
85+
86+
return model
87+
88+
89+
@pytest.fixture
90+
def simple_onnx_model():
91+
"""Fixture that provides a simple ONNX model."""
92+
return create_simple_onnx_model()
93+
94+
95+
@pytest.fixture
96+
def onnx_model_file(tmp_path, simple_onnx_model):
97+
"""Fixture that provides a path to a saved ONNX model."""
98+
model_path = os.path.join(tmp_path, "test_model.onnx")
99+
onnx.save(simple_onnx_model, model_path)
100+
return model_path
101+
102+
103+
class TestRegionInspectImports:
104+
"""Test that the region_inspect module can be imported."""
105+
106+
def test_module_imports(self):
107+
"""Test that the module imports without errors when dependencies exist."""
108+
# This test will skip if the required dependencies don't exist
109+
try:
110+
from modelopt.onnx.quantization.autotune import region_inspect
111+
112+
assert hasattr(region_inspect, "inspect_region_search")
113+
assert hasattr(region_inspect, "main")
114+
except ImportError as e:
115+
pytest.skip(f"Required dependencies not available: {e}")
116+
117+
118+
class TestRegionInspectWithMocks:
119+
"""Test region_inspect functionality with mocked dependencies."""
120+
121+
@patch("modelopt.onnx.quantization.autotune.region_inspect.CombinedRegionSearch")
122+
@patch("modelopt.onnx.quantization.autotune.region_inspect.has_quantizable_operations")
123+
def test_inspect_region_search_basic(
124+
self, mock_has_quantizable, mock_combined_search, onnx_model_file
125+
):
126+
"""Test basic functionality of inspect_region_search with mocked dependencies."""
127+
try:
128+
from modelopt.onnx.quantization.autotune.region_inspect import inspect_region_search
129+
except ImportError:
130+
pytest.skip("Required dependencies not available")
131+
132+
# Setup mocks
133+
mock_region = Mock()
134+
mock_region.type = Mock(value="LEAF")
135+
mock_region.inputs = ["input1"]
136+
mock_region.outputs = ["output1"]
137+
mock_region.children = []
138+
mock_region.get_region_nodes_and_descendants.return_value = [Mock(), Mock()]
139+
mock_region.get_children.return_value = []
140+
141+
mock_search_instance = Mock()
142+
mock_search_instance.search_regions.return_value = [mock_region]
143+
mock_search_instance.print_tree = Mock()
144+
mock_combined_search.return_value = mock_search_instance
145+
146+
mock_has_quantizable.return_value = True
147+
148+
# Call the function
149+
result = inspect_region_search(
150+
onnx_path=onnx_model_file, max_sequence_size=10, include_all_regions=False
151+
)
152+
153+
# Verify the function was called correctly
154+
assert mock_combined_search.called
155+
assert mock_search_instance.search_regions.called
156+
assert isinstance(result, list)
157+
158+
@patch("modelopt.onnx.quantization.autotune.region_inspect.CombinedRegionSearch")
159+
@patch("modelopt.onnx.quantization.autotune.region_inspect.has_quantizable_operations")
160+
def test_inspect_region_search_with_custom_params(
161+
self, mock_has_quantizable, mock_combined_search, onnx_model_file
162+
):
163+
"""Test inspect_region_search with custom parameters."""
164+
try:
165+
from modelopt.onnx.quantization.autotune.region_inspect import inspect_region_search
166+
except ImportError:
167+
pytest.skip("Required dependencies not available")
168+
169+
# Setup mocks
170+
mock_region = Mock()
171+
mock_region.type = Mock(value="COMPOSITE")
172+
mock_region.inputs = ["input1"]
173+
mock_region.outputs = ["output1"]
174+
mock_region.children = []
175+
mock_region.get_region_nodes_and_descendants.return_value = [Mock()]
176+
mock_region.get_children.return_value = []
177+
178+
mock_search_instance = Mock()
179+
mock_search_instance.search_regions.return_value = [mock_region]
180+
mock_search_instance.print_tree = Mock()
181+
mock_combined_search.return_value = mock_search_instance
182+
183+
mock_has_quantizable.return_value = True
184+
185+
# Call with custom parameters
186+
result = inspect_region_search(
187+
onnx_path=onnx_model_file, max_sequence_size=20, include_all_regions=True
188+
)
189+
190+
# Verify custom parameters were used
191+
assert mock_combined_search.called
192+
call_kwargs = mock_combined_search.call_args[1]
193+
assert call_kwargs.get("maximum_sequence_region_size") == 20
194+
assert isinstance(result, list)
195+
196+
@patch("modelopt.onnx.quantization.autotune.region_inspect.CombinedRegionSearch")
197+
@patch("modelopt.onnx.quantization.autotune.region_inspect.has_quantizable_operations")
198+
def test_inspect_region_search_filtering(
199+
self, mock_has_quantizable, mock_combined_search, onnx_model_file
200+
):
201+
"""Test that regions without quantizable operations are filtered out."""
202+
try:
203+
from modelopt.onnx.quantization.autotune.region_inspect import inspect_region_search
204+
except ImportError:
205+
pytest.skip("Required dependencies not available")
206+
207+
# Setup mocks - one region with quantizable ops, one without
208+
mock_region_quantizable = Mock()
209+
mock_region_quantizable.type = Mock(value="LEAF")
210+
mock_region_quantizable.inputs = ["input1"]
211+
mock_region_quantizable.outputs = ["output1"]
212+
mock_region_quantizable.get_region_nodes_and_descendants.return_value = [Mock()]
213+
mock_region_quantizable.get_children.return_value = []
214+
215+
mock_region_non_quantizable = Mock()
216+
mock_region_non_quantizable.type = Mock(value="LEAF")
217+
mock_region_non_quantizable.inputs = ["input2"]
218+
mock_region_non_quantizable.outputs = ["output2"]
219+
mock_region_non_quantizable.get_region_nodes_and_descendants.return_value = [Mock()]
220+
mock_region_non_quantizable.get_children.return_value = []
221+
222+
mock_search_instance = Mock()
223+
mock_search_instance.search_regions.return_value = [
224+
mock_region_quantizable,
225+
mock_region_non_quantizable,
226+
]
227+
mock_search_instance.print_tree = Mock()
228+
mock_combined_search.return_value = mock_search_instance
229+
230+
# First region has quantizable ops, second doesn't
231+
mock_has_quantizable.side_effect = [True, False]
232+
233+
# Call with filtering enabled
234+
result = inspect_region_search(
235+
onnx_path=onnx_model_file, max_sequence_size=10, include_all_regions=False
236+
)
237+
238+
# Should only return the quantizable region
239+
assert len(result) == 1
240+
241+
242+
class TestRegionInspectMain:
243+
"""Test the main CLI entry point."""
244+
245+
@patch("modelopt.onnx.quantization.autotune.region_inspect.inspect_region_search")
246+
def test_main_success(self, mock_inspect, onnx_model_file):
247+
"""Test main function with successful execution."""
248+
try:
249+
from modelopt.onnx.quantization.autotune.region_inspect import main
250+
except ImportError:
251+
pytest.skip("Required dependencies not available")
252+
253+
mock_inspect.return_value = [Mock(), Mock()]
254+
255+
with patch("sys.argv", ["region_inspect", "--model", onnx_model_file]):
256+
exit_code = main()
257+
assert exit_code == 0
258+
assert mock_inspect.called
259+
260+
@patch("modelopt.onnx.quantization.autotune.region_inspect.inspect_region_search")
261+
def test_main_with_verbose(self, mock_inspect, onnx_model_file):
262+
"""Test main function with verbose flag."""
263+
try:
264+
from modelopt.onnx.quantization.autotune.region_inspect import main
265+
except ImportError:
266+
pytest.skip("Required dependencies not available")
267+
268+
mock_inspect.return_value = [Mock()]
269+
270+
with patch("sys.argv", ["region_inspect", "--model", onnx_model_file, "--verbose"]):
271+
exit_code = main()
272+
assert exit_code == 0
273+
274+
@patch("modelopt.onnx.quantization.autotune.region_inspect.inspect_region_search")
275+
def test_main_with_custom_max_sequence_size(self, mock_inspect, onnx_model_file):
276+
"""Test main function with custom max_sequence_size."""
277+
try:
278+
from modelopt.onnx.quantization.autotune.region_inspect import main
279+
except ImportError:
280+
pytest.skip("Required dependencies not available")
281+
282+
mock_inspect.return_value = [Mock()]
283+
284+
with patch(
285+
"sys.argv", ["region_inspect", "--model", onnx_model_file, "--max-sequence-size", "20"]
286+
):
287+
exit_code = main()
288+
assert exit_code == 0
289+
# Verify max_sequence_size parameter was passed
290+
call_kwargs = mock_inspect.call_args[1]
291+
assert call_kwargs.get("max_sequence_size") == 20
292+
293+
@patch("modelopt.onnx.quantization.autotune.region_inspect.inspect_region_search")
294+
def test_main_with_include_all_regions(self, mock_inspect, onnx_model_file):
295+
"""Test main function with include_all_regions flag."""
296+
try:
297+
from modelopt.onnx.quantization.autotune.region_inspect import main
298+
except ImportError:
299+
pytest.skip("Required dependencies not available")
300+
301+
mock_inspect.return_value = [Mock()]
302+
303+
with patch(
304+
"sys.argv", ["region_inspect", "--model", onnx_model_file, "--include-all-regions"]
305+
):
306+
exit_code = main()
307+
assert exit_code == 0
308+
# Verify include_all_regions parameter was passed
309+
call_kwargs = mock_inspect.call_args[1]
310+
assert call_kwargs.get("include_all_regions") is True
311+
312+
@patch("modelopt.onnx.quantization.autotune.region_inspect.inspect_region_search")
313+
def test_main_failure(self, mock_inspect, onnx_model_file):
314+
"""Test main function with execution failure."""
315+
try:
316+
from modelopt.onnx.quantization.autotune.region_inspect import main
317+
except ImportError:
318+
pytest.skip("Required dependencies not available")
319+
320+
mock_inspect.side_effect = Exception("Test error")
321+
322+
with patch("sys.argv", ["region_inspect", "--model", onnx_model_file]):
323+
exit_code = main()
324+
assert exit_code == 1
325+
326+
327+
class TestRegionInspectModelLoading:
328+
"""Test model loading functionality."""
329+
330+
@patch("modelopt.onnx.quantization.autotune.region_inspect.CombinedRegionSearch")
331+
@patch("modelopt.onnx.quantization.autotune.region_inspect.has_quantizable_operations")
332+
def test_loads_valid_onnx_model(
333+
self, mock_has_quantizable, mock_combined_search, onnx_model_file
334+
):
335+
"""Test that a valid ONNX model can be loaded."""
336+
try:
337+
from modelopt.onnx.quantization.autotune.region_inspect import inspect_region_search
338+
except ImportError:
339+
pytest.skip("Required dependencies not available")
340+
341+
# Setup minimal mocks
342+
mock_region = Mock()
343+
mock_region.type = Mock(value="LEAF")
344+
mock_region.inputs = []
345+
mock_region.outputs = []
346+
mock_region.get_region_nodes_and_descendants.return_value = []
347+
mock_region.get_children.return_value = []
348+
349+
mock_search_instance = Mock()
350+
mock_search_instance.search_regions.return_value = [mock_region]
351+
mock_search_instance.print_tree = Mock()
352+
mock_combined_search.return_value = mock_search_instance
353+
mock_has_quantizable.return_value = False
354+
355+
# Should not raise an exception
356+
result = inspect_region_search(onnx_model_file)
357+
assert isinstance(result, list)
358+
359+
def test_fails_on_nonexistent_file(self):
360+
"""Test that loading a non-existent file raises an error."""
361+
try:
362+
from modelopt.onnx.quantization.autotune.region_inspect import inspect_region_search
363+
except ImportError:
364+
pytest.skip("Required dependencies not available")
365+
366+
with pytest.raises(Exception): # Could be FileNotFoundError or other
367+
inspect_region_search("/nonexistent/path/to/model.onnx")

0 commit comments

Comments
 (0)