Skip to content

Commit 9d45a31

Browse files
committed
improve coverage
1 parent c1c9361 commit 9d45a31

9 files changed

Lines changed: 89 additions & 22 deletions

File tree

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def test_unflatten_flatten_encoder_decoder_cache(self):
121121
)
122122
self.assertEqual(0, max_diff(c2, c2)["abs"])
123123
self.assertIsInstance(c2, transformers.cache_utils.EncoderDecoderCache)
124+
self.assertEqual(max_diff(c2, c2)["abs"], 0)
124125
flat, _spec = torch.utils._pytree.tree_flatten(c2)
125126
self.assertIsInstance(flat, list)
126127
self.assertEqual(len(flat), 12)

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
extract_subset_of_nodes,
3131
make_submodel,
3232
select_model_inputs_outputs,
33+
_enumerate_model_node_outputs,
3334
)
3435

3536

@@ -602,6 +603,12 @@ def _get_model_select(self):
602603
)
603604
return onnx_model
604605

606+
def test__enumerate_model_node_outputs(self):
607+
model = self._get_model_select()
608+
outputs1 = list(_enumerate_model_node_outputs(model, order=False))
609+
outputs2 = list(_enumerate_model_node_outputs(model, order=True))
610+
self.assertEqual(set(outputs1), set(outputs2))
611+
605612
def test_select_model_inputs_outputs(self):
606613
def enumerate_model_tensors(model):
607614
for tensor in _get_all_tensors(model):

_unittests/ut_torch_models/test_validate_models.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,7 @@
1414
from onnx_diagnostic.torch_models.validate import validate_model
1515

1616

17-
torch29_and_tr_main = not has_torch("2.9.9") and has_transformers("4.99999")
18-
19-
2017
class TestValidateModel(ExtTestCase):
21-
@unittest.skipIf(torch29_and_tr_main, "combination not working")
2218
@requires_transformers("4.53")
2319
@requires_torch("2.7.99")
2420
@requires_experimental()
@@ -40,12 +36,12 @@ def test_validate_tiny_llms_bfloat16(self):
4036
dtype="bfloat16",
4137
device="cuda",
4238
runtime="orteval",
39+
optimization="default+onnxruntime+os_ort",
4340
)
4441
self.assertLess(summary["disc_onnx_ort_run_abs"], 2e-2)
4542
self.assertIn("onnx_filename", data)
4643
self.clean_dump()
4744

48-
@unittest.skipIf(torch29_and_tr_main, "combination not working")
4945
@requires_transformers("4.57") # 4.53 works for some jobs fails due to no space left
5046
@requires_torch("2.9.99") # 2.9 works for some jobs fails due to no space left
5147
@requires_experimental()
@@ -68,7 +64,6 @@ def test_validate_microsoft_phi4_reasoning(self):
6864
self.assertIn("onnx_filename", data)
6965
self.clean_dump()
7066

71-
@unittest.skipIf(torch29_and_tr_main, "combination not working")
7267
@requires_transformers("4.53")
7368
@requires_torch("2.8.99")
7469
@requires_experimental()

_unittests/ut_torch_models/test_validate_whole_models1.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@
2424
from onnx_diagnostic.tasks import supported_tasks
2525

2626

27-
torch29_and_tr_main = not has_torch("2.9.9") and has_transformers("4.99999")
28-
29-
3027
class TestValidateWholeModels1(ExtTestCase):
3128
def test_a_get_inputs_for_task(self):
3229
fcts = supported_tasks()
@@ -205,7 +202,6 @@ def test_k_filter_inputs(self):
205202
ni, nd = filter_inputs(inputs, dynamic_shapes=ds, drop_names=["a"], model=["a", "b"])
206203
self.assertEqual((ni, nd), (((None,), {"b": 4}), {"b": 30}))
207204

208-
@unittest.skipIf(torch29_and_tr_main, "combination not working")
209205
@requires_torch("2.9.99")
210206
@hide_stdout()
211207
@ignore_warnings(FutureWarning)

_unittests/ut_torch_models/test_validate_whole_models2.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,8 @@
1212
)
1313
from onnx_diagnostic.torch_models.validate import validate_model
1414

15-
torch29_and_tr_main = not has_torch("2.9.9") and has_transformers("4.99999")
16-
1715

1816
class TestValidateWholeModels2(ExtTestCase):
19-
@unittest.skipIf(torch29_and_tr_main, "combination not working")
2017
@requires_torch("2.9")
2118
@hide_stdout()
2219
@ignore_warnings(FutureWarning)

_unittests/ut_torch_models/test_validate_whole_models3.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,8 @@
1010
)
1111
from onnx_diagnostic.torch_models.validate import validate_model
1212

13-
torch29_and_tr_main = not has_torch("2.9.9") and has_transformers("4.99999")
14-
1513

1614
class TestValidateWholeModels3(ExtTestCase):
17-
@unittest.skipIf(torch29_and_tr_main, "combination not working")
1815
@requires_torch("2.7")
1916
@hide_stdout()
2017
@ignore_warnings(FutureWarning)

_unittests/ut_xrun_doc/test_command_lines.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,60 @@ def test_parser_validate(self):
6868
text = st.getvalue()
6969
self.assertIn("mid", text)
7070

71+
def test_parser_validate_cmd(self):
72+
parser = get_parser_validate()
73+
args = parser.parse_args(
74+
[
75+
"-m",
76+
"arnir0/Tiny-LLM",
77+
"--run",
78+
"-v",
79+
"1",
80+
"--mop",
81+
"cache_implementation=static",
82+
"--iop",
83+
"cls_cache=StaticCache",
84+
"--patch",
85+
]
86+
)
87+
self.assertEqual(args.mid, "arnir0/Tiny-LLM")
88+
self.assertEqual(args.run, True)
89+
self.assertEqual(args.patch, True)
90+
self.assertEqual(args.verbose, 1)
91+
self.assertEqual(args.mop, {"cache_implementation": "static"})
92+
self.assertEqual(args.iop, {"cls_cache": "StaticCache"})
93+
args = parser.parse_args(
94+
[
95+
"-m",
96+
"arnir0/Tiny-LLM",
97+
"--run",
98+
"-v",
99+
"1",
100+
"--mop",
101+
"cache_implementation=static",
102+
"--iop",
103+
"cls_cache=StaticCache",
104+
"--patch",
105+
"patch_sympy=False",
106+
"--patch",
107+
"patch_torch=False",
108+
]
109+
)
110+
self.assertEqual(args.mid, "arnir0/Tiny-LLM")
111+
self.assertEqual(args.run, True)
112+
self.assertEqual(
113+
args.patch,
114+
{
115+
"patch_diffusers": True,
116+
"patch_sympy": False,
117+
"patch_torch": False,
118+
"patch_transformers": True,
119+
},
120+
)
121+
self.assertEqual(args.verbose, 1)
122+
self.assertEqual(args.mop, {"cache_implementation": "static"})
123+
self.assertEqual(args.iop, {"cls_cache": "StaticCache"})
124+
71125
def test_parser_stats(self):
72126
st = StringIO()
73127
with redirect_stdout(st):
@@ -82,6 +136,26 @@ def test_parser_agg(self):
82136
text = st.getvalue()
83137
self.assertIn("--recent", text)
84138

139+
def test_parser_agg_cmd(self):
140+
parser = get_parser_agg()
141+
args = parser.parse_args(
142+
[
143+
"o.xlsx",
144+
"*.zip",
145+
"--sbs",
146+
"dynamo:exporter=onnx-dynamo,opt=ir,attn_impl=eager",
147+
"--sbs",
148+
"custom:exporter=custom,opt=default,attn_impl=eager",
149+
]
150+
)
151+
self.assertEqual(
152+
args.sbs,
153+
{
154+
"custom": {"attn_impl": "eager", "exporter": "custom", "opt": "default"},
155+
"dynamo": {"attn_impl": "eager", "exporter": "onnx-dynamo", "opt": "ir"},
156+
},
157+
)
158+
85159
def test_parser_sbs(self):
86160
st = StringIO()
87161
with redirect_stdout(st):

onnx_diagnostic/_command_lines_parser.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -517,12 +517,12 @@ def get_parser_validate(name: str = "validate") -> ArgumentParser:
517517
nargs="*",
518518
help=textwrap.dedent(
519519
"""
520-
Applies patches before exporting, it can be a boolean
521-
to enable to disable the patches or be more finetuned
522-
(default is True). It is possible to disable patch for torch
523-
by adding:
524-
--patch "patch_sympy=False" --patch "patch_torch=False"
525-
""".strip(
520+
Applies patches before exporting, it can be a boolean
521+
to enable to disable the patches or be more finetuned
522+
(default is True). It is possible to disable patch for torch
523+
by adding:
524+
--patch "patch_sympy=False" --patch "patch_torch=False"
525+
""".strip(
526526
"\n"
527527
)
528528
),

0 commit comments

Comments
 (0)