Skip to content

Commit 0aa4ce1

Browse files
committed
fix patch
1 parent f63ae39 commit 0aa4ce1

4 files changed

Lines changed: 150 additions & 113 deletions

File tree

_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_zero_shot_image_classification(self):
1717
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
1818
expected = model(**inputs)
1919
model(**data["inputs2"])
20-
with torch_export_patches(patch_transformers=True, verbose=10):
20+
with torch_export_patches(patch_torch=True, patch_transformers=True, verbose=10):
2121
ep = torch.export.export(
2222
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
2323
)

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 118 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -677,157 +677,164 @@ def _get_seqlen(cls) -> torch.Tensor:
677677
@requires_cuda()
678678
def test_plug_multi_head_attention_qwen25_packed_float16(self):
679679
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
680-
qwen_sdpa_attention_packed_versatile,
680+
qwen_sdpa_attention_versatile as qwen_sdpa_attention_packed_versatile,
681681
)
682682

683-
inputs = (
684-
torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
685-
torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
686-
torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
687-
self._get_seqlen().to("cuda"),
688-
)
683+
with self.set_env("QWEN25ATTENTION", "PACKED"):
684+
inputs = (
685+
torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
686+
torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
687+
torch.rand((1, 16, 1292, 80), dtype=torch.float16).to("cuda"),
688+
self._get_seqlen().to("cuda"),
689+
)
689690

690-
results = qwen_sdpa_attention_packed_versatile.verify(
691-
*inputs, scaling=0.5, num_heads=16
692-
)
693-
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
694-
self.assertEqual(len(results.eager_outputs), len(results.diffs))
695-
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
696-
self.assertLess(results.diffs[0]["abs"], 0.01)
691+
results = qwen_sdpa_attention_packed_versatile.verify(
692+
*inputs, scaling=0.5, num_heads=16
693+
)
694+
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
695+
self.assertEqual(len(results.eager_outputs), len(results.diffs))
696+
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
697+
self.assertLess(results.diffs[0]["abs"], 0.01)
697698

698-
results = qwen_sdpa_attention_packed_versatile.verify(
699-
*inputs, scaling=0.11180339887498948, num_heads=16
700-
)
701-
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
702-
self.assertEqual(len(results.eager_outputs), len(results.diffs))
703-
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
704-
self.assertLess(results.diffs[0]["abs"], 0.01)
699+
results = qwen_sdpa_attention_packed_versatile.verify(
700+
*inputs, scaling=0.11180339887498948, num_heads=16
701+
)
702+
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
703+
self.assertEqual(len(results.eager_outputs), len(results.diffs))
704+
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
705+
self.assertLess(results.diffs[0]["abs"], 0.01)
705706

706707
@requires_onnxruntime("1.25")
707708
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
708709
def test_plug_multi_head_attention_qwen25_loopmha_float16(self):
709710
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
710-
qwen_sdpa_attention_loopmha_versatile,
711+
qwen_sdpa_attention_versatile as qwen_sdpa_attention_loopmha_versatile,
711712
)
712713

713-
inputs = (
714-
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
715-
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
716-
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
717-
self._get_seqlen(),
718-
)
714+
with self.set_env("QWEN25ATTENTION", "LOOPMHA"):
715+
inputs = (
716+
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
717+
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
718+
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
719+
self._get_seqlen(),
720+
)
719721

720-
results = qwen_sdpa_attention_loopmha_versatile.verify(
721-
*inputs,
722-
scaling=0.5,
723-
num_heads=16,
724-
dump_onnx_model=self.get_dump_file(
725-
"test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx"
726-
),
727-
)
728-
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
729-
self.assertEqual(len(results.eager_outputs), len(results.diffs))
730-
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
731-
self.assertLess(results.diffs[0]["abs"], 0.01)
722+
results = qwen_sdpa_attention_loopmha_versatile.verify(
723+
*inputs,
724+
scaling=0.5,
725+
num_heads=16,
726+
dump_onnx_model=self.get_dump_file(
727+
"test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx"
728+
),
729+
)
730+
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
731+
self.assertEqual(len(results.eager_outputs), len(results.diffs))
732+
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
733+
self.assertLess(results.diffs[0]["abs"], 0.01)
732734

733-
results = qwen_sdpa_attention_loopmha_versatile.verify(
734-
*inputs, scaling=0.11180339887498948, num_heads=16
735-
)
736-
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
737-
self.assertEqual(len(results.eager_outputs), len(results.diffs))
738-
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
739-
self.assertLess(results.diffs[0]["abs"], 0.01)
735+
results = qwen_sdpa_attention_loopmha_versatile.verify(
736+
*inputs, scaling=0.11180339887498948, num_heads=16
737+
)
738+
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
739+
self.assertEqual(len(results.eager_outputs), len(results.diffs))
740+
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
741+
self.assertLess(results.diffs[0]["abs"], 0.01)
740742

741743
@requires_onnxruntime("1.25")
742744
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
743745
def test_plug_multi_head_attention_qwen25_loopmha_float32(self):
744746
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
745-
qwen_sdpa_attention_loopmha_versatile,
747+
qwen_sdpa_attention_versatile as qwen_sdpa_attention_loopmha_versatile,
746748
)
747749

748-
inputs = (
749-
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
750-
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
751-
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
752-
self._get_seqlen(),
753-
)
750+
with self.set_env("QWEN25ATTENTION", "LOOPMHA"):
751+
inputs = (
752+
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
753+
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
754+
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
755+
self._get_seqlen(),
756+
)
754757

755-
results = qwen_sdpa_attention_loopmha_versatile.verify(
756-
*inputs,
757-
scaling=0.5,
758-
num_heads=16,
759-
dump_onnx_model=self.get_dump_file(
760-
"test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx"
761-
),
762-
)
763-
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
764-
self.assertEqual(len(results.eager_outputs), len(results.diffs))
765-
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
766-
self.assertLess(results.diffs[0]["abs"], 1e-5)
758+
results = qwen_sdpa_attention_loopmha_versatile.verify(
759+
*inputs,
760+
scaling=0.5,
761+
num_heads=16,
762+
dump_onnx_model=self.get_dump_file(
763+
"test_plug_packed_multi_head_attention_qwen25_loopmha_float16.onnx"
764+
),
765+
)
766+
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
767+
self.assertEqual(len(results.eager_outputs), len(results.diffs))
768+
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
769+
self.assertLess(results.diffs[0]["abs"], 1e-5)
767770

768-
results = qwen_sdpa_attention_loopmha_versatile.verify(
769-
*inputs, scaling=0.11180339887498948, num_heads=16
770-
)
771-
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
772-
self.assertEqual(len(results.eager_outputs), len(results.diffs))
773-
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
774-
self.assertLess(results.diffs[0]["abs"], 1e-5)
771+
results = qwen_sdpa_attention_loopmha_versatile.verify(
772+
*inputs, scaling=0.11180339887498948, num_heads=16
773+
)
774+
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
775+
self.assertEqual(len(results.eager_outputs), len(results.diffs))
776+
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
777+
self.assertLess(results.diffs[0]["abs"], 1e-5)
775778

776779
@requires_onnxruntime("1.25")
777780
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
778781
def test_plug_multi_head_attention_qwen25_loopa24_float16(self):
779782
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
780-
qwen_sdpa_attention_loopa24_versatile,
783+
qwen_sdpa_attention_versatile as qwen_sdpa_attention_loopa24_versatile,
781784
)
782785

783-
inputs = (
784-
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
785-
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
786-
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
787-
self._get_seqlen(),
788-
)
786+
with self.set_env("QWEN25ATTENTION", "LOOO24"):
787+
inputs = (
788+
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
789+
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
790+
torch.rand((1, 16, 1292, 80), dtype=torch.float16),
791+
self._get_seqlen(),
792+
)
789793

790-
results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5)
791-
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
792-
self.assertEqual(len(results.eager_outputs), len(results.diffs))
793-
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-2)
794-
self.assertLess(results.diffs[0]["abs"], 1e-2)
794+
results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5)
795+
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
796+
self.assertEqual(len(results.eager_outputs), len(results.diffs))
797+
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-2)
798+
self.assertLess(results.diffs[0]["abs"], 1e-2)
795799

796-
results = qwen_sdpa_attention_loopa24_versatile.verify(
797-
*inputs, scaling=0.11180339887498948
798-
)
799-
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
800-
self.assertEqual(len(results.eager_outputs), len(results.diffs))
801-
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.005)
802-
self.assertLess(results.diffs[0]["abs"], 0.005)
800+
results = qwen_sdpa_attention_loopa24_versatile.verify(
801+
*inputs, scaling=0.11180339887498948
802+
)
803+
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
804+
self.assertEqual(len(results.eager_outputs), len(results.diffs))
805+
self.assertEqualArray(
806+
results.eager_outputs[0], results.onnx_outputs[0], atol=0.005
807+
)
808+
self.assertLess(results.diffs[0]["abs"], 0.005)
803809

804810
@requires_onnxruntime("1.25")
805811
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
806812
def test_plug_multi_head_attention_qwen25_loopa24_float32(self):
807813
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
808-
qwen_sdpa_attention_loopa24_versatile,
814+
qwen_sdpa_attention_versatile as qwen_sdpa_attention_loopa24_versatile,
809815
)
810816

811-
inputs = (
812-
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
813-
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
814-
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
815-
self._get_seqlen(),
816-
)
817+
with self.set_env("QWEN25ATTENTION", "LOOO24"):
818+
inputs = (
819+
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
820+
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
821+
torch.rand((1, 16, 1292, 80), dtype=torch.float32),
822+
self._get_seqlen(),
823+
)
817824

818-
results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5)
819-
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
820-
self.assertEqual(len(results.eager_outputs), len(results.diffs))
821-
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
822-
self.assertLess(results.diffs[0]["abs"], 1e-5)
825+
results = qwen_sdpa_attention_loopa24_versatile.verify(*inputs, scaling=0.5)
826+
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
827+
self.assertEqual(len(results.eager_outputs), len(results.diffs))
828+
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
829+
self.assertLess(results.diffs[0]["abs"], 1e-5)
823830

824-
results = qwen_sdpa_attention_loopa24_versatile.verify(
825-
*inputs, scaling=0.11180339887498948
826-
)
827-
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
828-
self.assertEqual(len(results.eager_outputs), len(results.diffs))
829-
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
830-
self.assertLess(results.diffs[0]["abs"], 1e-5)
831+
results = qwen_sdpa_attention_loopa24_versatile.verify(
832+
*inputs, scaling=0.11180339887498948
833+
)
834+
self.assertEqual(len(results.eager_outputs), len(results.onnx_outputs))
835+
self.assertEqual(len(results.eager_outputs), len(results.diffs))
836+
self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
837+
self.assertLess(results.diffs[0]["abs"], 1e-5)
831838

832839
@unittest.skipIf(not patch_funnel, "Funnel not part of this transformers")
833840
def test_model_funnel(self):

onnx_diagnostic/ext_test_case.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import sys
1414
import unittest
1515
import warnings
16-
from contextlib import redirect_stderr, redirect_stdout
16+
from contextlib import redirect_stderr, redirect_stdout, contextmanager
1717
from io import StringIO
1818
from timeit import Timer
1919
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
@@ -1465,3 +1465,16 @@ def subloop(self, *args, verbose: int = 0):
14651465
if verbose:
14661466
print(f"[subloop] it={it!r}")
14671467
yield it
1468+
1469+
@contextmanager
def set_env(self, varname: str, value: str):
    """
    Temporarily sets environment variable ``varname`` to ``value``
    and restores the previous state when the context exits,
    even if the body raises an exception.

    :param varname: name of the environment variable to override
    :param value: value assigned to the variable inside the context

    If the variable was not defined before entering the context,
    it is removed on exit rather than being left as an empty string,
    so code checking ``varname in os.environ`` or
    ``os.environ.get(varname)`` sees the original state again.
    """
    # None means "was not defined"; an empty string is a legitimate
    # previous value and must be restored as such.
    old_value = os.environ.get(varname)
    os.environ[varname] = value
    try:
        yield
    finally:
        if old_value is None:
            # pop with a default: the body may already have deleted it.
            os.environ.pop(varname, None)
        else:
            os.environ[varname] = old_value

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,25 @@ def patched__get_range_constraints(
158158
),
159159
len(export_graph_signature.input_specs),
160160
)
161+
161162
combined_args = torch.export._trace._combine_args(mod, args, kwargs)
162163

164+
# _combine_args does not preserve the order.
165+
if isinstance(combined_args, dict):
166+
input_names = [
167+
s.arg.name
168+
for s in export_graph_signature.input_specs
169+
if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
170+
]
171+
new_args = {}
172+
for k in input_names:
173+
if k in combined_args:
174+
new_args[k] = combined_args[k]
175+
for k in combined_args:
176+
if k not in new_args:
177+
new_args[k] = combined_args[k]
178+
combined_args = new_args
179+
163180
range_constraints = torch._export.non_strict_utils.make_constraints(
164181
fake_mode, gm, combined_args, dynamic_shapes, num_lifted
165182
)

0 commit comments

Comments
 (0)