Skip to content

Commit 10612a3

Browse files
authored
Add patch for DynamicDimConstraintPrinter (#363)
* Add patch for DynamicDimConstraintPrinter * fix for versions * fix * doc
1 parent c549cfe commit 10612a3

7 files changed

Lines changed: 170 additions & 33 deletions

File tree

CHANGELOGS.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ Change Logs
44
0.8.7
55
+++++
66

7+
* :pr:`363`: patch for DynamicDimConstraintPrinter
8+
* :pr:`360`: preliminary work for phi4
9+
710
0.8.6
811
+++++
912

_doc/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ def linkcode_resolve(domain, info):
134134
("py:class", "onnx_ir.Tuple"),
135135
("py:class", "pandas.core.groupby.generic.DataFrameGroupBy"),
136136
("py:class", "pipeline.Pipeline"),
137+
("py:class", "torch._guards.Source"),
138+
("py:class", "torch._ops.HigherOrderOperator"),
137139
("py:class", "torch.fx.passes.operator_support.OperatorSupport"),
138140
("py:class", "torch.fx.proxy.TracerBase"),
139141
("py:class", "torch.FloatTensor"),

_unittests/ut_torch_export_patches/test_patch_loops.py

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
import torch
3-
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, has_torch
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
44
from onnx_diagnostic.helpers.torch_helper import (
55
is_torchdynamo_exporting,
66
fake_torchdynamo_exporting,
@@ -11,6 +11,7 @@
1111
register_patched_expressions,
1212
patched_float_arange,
1313
)
14+
from onnx_diagnostic.torch_export_patches import torch_export_patches
1415

1516

1617
class TestOnnxExportErrors(ExtTestCase):
@@ -20,9 +21,23 @@ def test_patched_expressions(self):
2021
names = {_[0] for _ in res}
2122
self.assertIn("float_arange", names)
2223

23-
@requires_torch("2.8")
24-
def test_filter_position_ids(self):
24+
def test_float_arange(self):
25+
register_patched_expressions()
26+
rg = torch.arange(0.0, 0.99, 0.1)
27+
rg2 = torch.ops.patched.float_arange(
28+
torch.tensor(0.0), torch.tensor(0.99), torch.tensor(0.1)
29+
)
30+
rg3 = patched_float_arange(torch.tensor(0.0), torch.tensor(0.99), torch.tensor(0.1))
31+
self.assertEqualArray(rg, rg2, atol=1e-5)
32+
self.assertEqualArray(rg, rg3, atol=1e-5)
33+
with fake_torchdynamo_exporting():
34+
rg4 = patched_float_arange(
35+
torch.tensor(0.0), torch.tensor(0.99), torch.tensor(0.1)
36+
)
37+
self.assertEqualArray(rg, rg4, atol=1e-5)
2538

39+
@requires_torch("2.9.99")
40+
def test_filter_position_ids(self):
2641
def filter_position_ids(
2742
patch_attention_mask: torch.Tensor,
2843
position_ids: torch.Tensor,
@@ -42,15 +57,6 @@ def filter_position_ids(
4257
position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
4358
return position_ids
4459

45-
def float_arange(start, end, step):
46-
length = torch.sym_int((end - start) / step + (step * (1 - 1e-6)))
47-
torch._check(length > 0)
48-
res = torch.arange(0, length)
49-
torch._check(res.is_contiguous())
50-
fres = res.to(torch.float32)
51-
fstart = torch.tensor(start, dtype=torch.float32)
52-
return fres + fstart
53-
5460
def scan_filter_position_ids(
5561
patch_attention_mask: torch.Tensor,
5662
position_ids: torch.Tensor,
@@ -59,18 +65,21 @@ def scan_filter_position_ids(
5965
):
6066

6167
def body(p_attn_mask, position_ids_row):
62-
h_len = torch.tensor(1) / p_attn_mask[:, 0].sum()
63-
w_len = torch.tensor(1) / p_attn_mask[0].sum()
64-
fractional_coords_h = patched_float_arange(
65-
torch.tensor(0.0), torch.tensor(1 - 1e-6), h_len
68+
h_len = torch.tensor(1, dtype=p_attn_mask.dtype) / p_attn_mask[:, 0].sum()
69+
w_len = torch.tensor(1, dtype=p_attn_mask.dtype) / p_attn_mask[0].sum()
70+
torch._check(h_len.item() > 0)
71+
fractional_coords_h = torch.arange(
72+
torch.tensor(0.0, dtype=p_attn_mask.dtype),
73+
torch.tensor(1 - 1e-6, dtype=p_attn_mask.dtype),
74+
h_len,
6675
)
67-
fractional_coords_w = patched_float_arange(
68-
torch.tensor(0.0), torch.tensor(1 - 1e-6), w_len
76+
torch._check(w_len.item() > 0)
77+
fractional_coords_w = torch.arange(
78+
torch.tensor(0.0, dtype=p_attn_mask.dtype),
79+
torch.tensor(1 - 1e-6, dtype=p_attn_mask.dtype),
80+
w_len,
6981
)
7082

71-
# torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[:, 0].sum().item())
72-
# torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[0].sum().item())
73-
7483
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
7584
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
7685

@@ -116,17 +125,12 @@ def forward(self, patch_attention_mask, position_ids, boundaries):
116125
self.assertEqualArray(expected, got)
117126

118127
DYN = torch.export.Dim.DYNAMIC
119-
ep = torch.export.export(model, inputs, dynamic_shapes=({0: DYN}, {0: DYN}, {0: DYN}))
120-
try:
121-
got = ep.module()(*inputs)
122-
except Exception:
123-
# At least it exports, we need to remove the assert from the exported program.
124-
# Let's revisit this later.
125-
if has_torch("2.11"):
126-
raise
127-
got = None
128-
if got is not None:
129-
self.assertEqualArray(expected, got)
128+
with torch_export_patches(patch_torch=True):
129+
ep = torch.export.export(
130+
model, inputs, dynamic_shapes=({0: DYN}, {0: DYN}, {0: DYN})
131+
)
132+
got = ep.module()(*inputs)
133+
self.assertEqualArray(expected, got)
130134

131135

132136
if __name__ == "__main__":

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,91 @@ def forward(self, x, y):
579579
shape = output[0].args[0][0].meta["val"].shape
580580
self.assertEqual(str(shape), "torch.Size([Max(s17, s77)])")
581581

582+
@requires_torch("2.9.99")
583+
def test_patched_DynamicDimConstraintPrinter(self):
584+
def filter_position_ids(
585+
patch_attention_mask: torch.Tensor,
586+
position_ids: torch.Tensor,
587+
boundaries: torch.Tensor,
588+
num_patches_per_side: int,
589+
):
590+
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
591+
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[:, 0].sum())
592+
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[0].sum())
593+
594+
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
595+
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
596+
597+
pos_ids = (
598+
bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w
599+
).flatten()
600+
position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
601+
return position_ids
602+
603+
def scan_filter_position_ids(
604+
patch_attention_mask: torch.Tensor,
605+
position_ids: torch.Tensor,
606+
boundaries: torch.Tensor,
607+
num_patches_per_side: int,
608+
):
609+
610+
def body(p_attn_mask, position_ids_row):
611+
h_len = torch.tensor(1, dtype=p_attn_mask.dtype) / p_attn_mask[:, 0].sum()
612+
w_len = torch.tensor(1, dtype=p_attn_mask.dtype) / p_attn_mask[0].sum()
613+
torch._check(h_len.item() > 0)
614+
fractional_coords_h = torch.arange(
615+
torch.tensor(0.0, dtype=p_attn_mask.dtype),
616+
torch.tensor(1 - 1e-6, dtype=p_attn_mask.dtype),
617+
h_len,
618+
)
619+
torch._check(w_len.item() > 0)
620+
fractional_coords_w = torch.arange(
621+
torch.tensor(0.0, dtype=p_attn_mask.dtype),
622+
torch.tensor(1 - 1e-6, dtype=p_attn_mask.dtype),
623+
w_len,
624+
)
625+
626+
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
627+
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
628+
629+
pos_ids = (
630+
bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w
631+
).flatten()
632+
633+
row = position_ids_row.clone()
634+
row[p_attn_mask.view(-1)] = pos_ids
635+
return [row]
636+
637+
return torch.ops.higher_order.scan(
638+
body, [], [patch_attention_mask, position_ids], additional_inputs=[]
639+
)
640+
641+
class Model(torch.nn.Module):
642+
def forward(self, patch_attention_mask, position_ids, boundaries):
643+
res = scan_filter_position_ids(
644+
patch_attention_mask, position_ids, boundaries, 32
645+
)
646+
return res[0]
647+
648+
patch_attention_mask = torch.randint(0, 17, (32, 32, 32)) >= 1
649+
patch_attention_mask[:, :, :] = True
650+
position_ids = torch.zeros((32, 1024), dtype=torch.int64)
651+
boundaries = (torch.arange(33).to(torch.float32) / 33)[1:-1]
652+
inputs = (patch_attention_mask, position_ids, boundaries)
653+
654+
model = Model()
655+
true_expected = filter_position_ids(*(*inputs, 32))
656+
expected = model(*inputs)
657+
self.assertEqualArray(true_expected, expected)
658+
659+
DYN = torch.export.Dim.DYNAMIC
660+
with torch_export_patches(patch_torch=True):
661+
ep = torch.export.export(
662+
model, inputs, dynamic_shapes=({0: DYN}, {0: DYN}, {0: DYN})
663+
)
664+
got = ep.module()(*inputs)
665+
self.assertEqualArray(expected, got)
666+
582667

583668
if __name__ == "__main__":
584669
unittest.main(verbosity=2)

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ def _patch_torch(
221221
catch_constraints: bool,
222222
stop_if_static: int,
223223
) -> Tuple[Optional[Callable], ...]:
224+
import packaging.version as pv
224225
import torch
225226
import torch.jit
226227
import torch._export.non_strict_utils # produce_guards_and_solve_constraints
@@ -238,6 +239,11 @@ def _patch_torch(
238239
patched_ShapeEnv,
239240
)
240241

242+
if pv.Version(torch.__version__) >= pv.Version("2.9.99"):
243+
from .patches.patch_torch import patched_DynamicDimConstraintPrinter
244+
else:
245+
patched_DynamicDimConstraintPrinter = None
246+
241247
f___constrain_user_specified_dimhint_range = None
242248
f__broadcast_in_dim_meta = None
243249
f__broadcast_shapes = None
@@ -259,6 +265,17 @@ def _patch_torch(
259265
print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
260266
print("[torch_export_patches] patch pytorch")
261267

268+
# torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol
269+
if patched_DynamicDimConstraintPrinter is not None:
270+
f__print_symbol = (
271+
torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol
272+
)
273+
torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol = (
274+
patched_DynamicDimConstraintPrinter._print_Symbol
275+
)
276+
else:
277+
f__print_symbol = None
278+
262279
# torch.vmap
263280
f_vmap = torch.vmap
264281
torch.vmap = patched_vmap
@@ -392,6 +409,7 @@ def _patch_torch(
392409
f_shape_env__log_guard,
393410
f_shape_env__set_replacement,
394411
f_vmap,
412+
f__print_symbol,
395413
)
396414

397415

@@ -416,13 +434,18 @@ def _unpatch_torch(
416434
f_shape_env__log_guard: Optional[Callable],
417435
f_shape_env__set_replacement: Optional[Callable],
418436
f_vmap: Optional[Callable],
437+
f__print_symbol: Optional[Callable],
419438
):
420439
import torch
421440
import torch.jit
422441
import torch._export.non_strict_utils # produce_guards_and_solve_constraints
423442
from torch.fx.experimental.symbolic_shapes import ShapeEnv
424443

425444
# this should disappear when torch.jit is removed
445+
if f__print_symbol is not None:
446+
torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol = (
447+
f__print_symbol
448+
)
426449
torch.vmap = f_vmap
427450
torch.jit.isinstance = f_jit_isinstance
428451
torch._dynamo.mark_static_address = f_mark_static_address
@@ -992,6 +1015,7 @@ def torch_export_patches(
9921015
f_shape_env__log_guard,
9931016
f_shape_env__set_replacement,
9941017
f_vmap,
1018+
f__print_Symbol,
9951019
) = _patch_torch(
9961020
verbose, patch_details, patch_torch, catch_constraints, stop_if_static
9971021
)
@@ -1067,6 +1091,7 @@ def torch_export_patches(
10671091
f_shape_env__log_guard,
10681092
f_shape_env__set_replacement,
10691093
f_vmap,
1094+
f__print_Symbol,
10701095
)
10711096

10721097
if patch_transformers:

onnx_diagnostic/torch_export_patches/patch_expressions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,10 @@ def patched_selector(fct: Callable, patched_fct: Callable) -> Callable:
101101

102102

103103
def patched_float_arange(start, end, step):
104-
"""Patched arange when start, end, step are floats."""
104+
"""
105+
Patched arange when start, end, step are floats.
106+
This patch should not be needed after 2.10.
107+
"""
105108
if is_torchdynamo_exporting():
106109
return torch.ops.patched.float_arange(start, end, step)
107110
else:

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import traceback
66
from functools import reduce
77
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
8+
import sympy
89
import torch
910
from torch._subclasses.fake_tensor import FakeTensorMode
1011

@@ -1091,3 +1092,17 @@ def _greater_than_reduce(acc, x):
10911092
new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
10921093

10931094
return a.as_strided(shape, new_strides, a.storage_offset())
1095+
1096+
1097+
class patched_DynamicDimConstraintPrinter:
1098+
"""
1099+
Patches
1100+
``torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol``.
1101+
Valid for ``torch>=2.10``.
1102+
"""
1103+
1104+
def _print_Symbol(self, expr: sympy.Symbol) -> str:
1105+
assert isinstance(expr, sympy.Symbol), str(type(expr))
1106+
if self.symbol_to_source.get(expr):
1107+
return self.symbol_to_source[expr][0].name
1108+
return str(expr)

0 commit comments

Comments
 (0)