Skip to content

Commit 5035092

Browse files
committed
fix
1 parent 0aa4ce1 commit 5035092

3 files changed

Lines changed: 26 additions & 15 deletions

File tree

_unittests/ut_investigate/test_input_observer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
import unittest
33
import pandas
44
import torch
5-
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
5+
from onnx_diagnostic.ext_test_case import (
6+
ExtTestCase,
7+
requires_torch,
8+
hide_stdout,
9+
ignore_warnings,
10+
)
611
from onnx_diagnostic.investigate.input_observer import (
712
InputObserver,
813
_infer_dynamic_dimensions,
@@ -816,6 +821,8 @@ def forward(self, x=None, y=None):
816821
self.assertEqual(2, len(args))
817822
self.assertEqual(len([v for v in args.values() if v is not None]), 2)
818823

824+
@hide_stdout()
825+
@ignore_warnings(FutureWarning)
819826
def test_io_int_kwargs(self):
820827
class Model(torch.nn.Module):
821828
def forward(self, x=None, y=None, option=1):

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
requires_torch,
1313
ignore_warnings,
1414
has_onnxscript,
15+
requires_onnxscript,
1516
)
1617
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy, fake_torchdynamo_exporting
1718
from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
@@ -352,6 +353,7 @@ def forward(self, query, key, value):
352353
self.assertEqualArray(expected, got)
353354

354355
@requires_transformers("4.55")
356+
@requires_onnxscript("0.6.2")
355357
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
356358
def test_qwen_apply_multimodal_rotary_pos_emb(self):
357359
apply_multimodal_rotary_pos_emb = (

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -162,20 +162,22 @@ def patched__get_range_constraints(
162162
combined_args = torch.export._trace._combine_args(mod, args, kwargs)
163163

164164
# _combine_args does not preserve the order.
165-
if isinstance(combined_args, dict):
166-
input_names = [
167-
s.arg.name
168-
for s in export_graph_signature.input_specs
169-
if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
170-
]
171-
new_args = {}
172-
for k in input_names:
173-
if k in combined_args:
174-
new_args[k] = combined_args[k]
175-
for k in combined_args:
176-
if k not in new_args:
177-
new_args[k] = combined_args[k]
178-
combined_args = new_args
165+
assert isinstance(
166+
combined_args, dict
167+
), f"unexpected type {type(combined_args)} for 'combined_args'"
168+
input_names = [
169+
s.arg.name
170+
for s in export_graph_signature.input_specs
171+
if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
172+
]
173+
new_args = {}
174+
for k in input_names:
175+
if k in combined_args:
176+
new_args[k] = combined_args[k]
177+
for k in combined_args:
178+
if k not in new_args:
179+
new_args[k] = combined_args[k]
180+
combined_args = new_args
179181

180182
range_constraints = torch._export.non_strict_utils.make_constraints(
181183
fake_mode, gm, combined_args, dynamic_shapes, num_lifted

0 commit comments

Comments
 (0)