11import unittest
22import torch
3- from onnx_diagnostic .ext_test_case import ExtTestCase , requires_torch , has_torch
3+ from onnx_diagnostic .ext_test_case import ExtTestCase , requires_torch
44from onnx_diagnostic .helpers .torch_helper import (
55 is_torchdynamo_exporting ,
66 fake_torchdynamo_exporting ,
1111 register_patched_expressions ,
1212 patched_float_arange ,
1313)
14+ from onnx_diagnostic .torch_export_patches import torch_export_patches
1415
1516
1617class TestOnnxExportErrors (ExtTestCase ):
@@ -20,9 +21,23 @@ def test_patched_expressions(self):
2021 names = {_ [0 ] for _ in res }
2122 self .assertIn ("float_arange" , names )
2223
23- @requires_torch ("2.8" )
24- def test_filter_position_ids (self ):
24+ def test_float_arange (self ):
25+ register_patched_expressions ()
26+ rg = torch .arange (0.0 , 0.99 , 0.1 )
27+ rg2 = torch .ops .patched .float_arange (
28+ torch .tensor (0.0 ), torch .tensor (0.99 ), torch .tensor (0.1 )
29+ )
30+ rg3 = patched_float_arange (torch .tensor (0.0 ), torch .tensor (0.99 ), torch .tensor (0.1 ))
31+ self .assertEqualArray (rg , rg2 , atol = 1e-5 )
32+ self .assertEqualArray (rg , rg3 , atol = 1e-5 )
33+ with fake_torchdynamo_exporting ():
34+ rg4 = patched_float_arange (
35+ torch .tensor (0.0 ), torch .tensor (0.99 ), torch .tensor (0.1 )
36+ )
37+ self .assertEqualArray (rg , rg4 , atol = 1e-5 )
2538
39+ @requires_torch ("2.9.99" )
40+ def test_filter_position_ids (self ):
2641 def filter_position_ids (
2742 patch_attention_mask : torch .Tensor ,
2843 position_ids : torch .Tensor ,
@@ -42,15 +57,6 @@ def filter_position_ids(
4257 position_ids [batch_idx ][p_attn_mask .view (- 1 )] = pos_ids
4358 return position_ids
4459
45- def float_arange (start , end , step ):
46- length = torch .sym_int ((end - start ) / step + (step * (1 - 1e-6 )))
47- torch ._check (length > 0 )
48- res = torch .arange (0 , length )
49- torch ._check (res .is_contiguous ())
50- fres = res .to (torch .float32 )
51- fstart = torch .tensor (start , dtype = torch .float32 )
52- return fres + fstart
53-
5460 def scan_filter_position_ids (
5561 patch_attention_mask : torch .Tensor ,
5662 position_ids : torch .Tensor ,
@@ -59,18 +65,21 @@ def scan_filter_position_ids(
5965 ):
6066
6167 def body (p_attn_mask , position_ids_row ):
62- h_len = torch .tensor (1 ) / p_attn_mask [:, 0 ].sum ()
63- w_len = torch .tensor (1 ) / p_attn_mask [0 ].sum ()
64- fractional_coords_h = patched_float_arange (
65- torch .tensor (0.0 ), torch .tensor (1 - 1e-6 ), h_len
68+ h_len = torch .tensor (1 , dtype = p_attn_mask .dtype ) / p_attn_mask [:, 0 ].sum ()
69+ w_len = torch .tensor (1 , dtype = p_attn_mask .dtype ) / p_attn_mask [0 ].sum ()
70+ torch ._check (h_len .item () > 0 )
71+ fractional_coords_h = torch .arange (
72+ torch .tensor (0.0 , dtype = p_attn_mask .dtype ),
73+ torch .tensor (1 - 1e-6 , dtype = p_attn_mask .dtype ),
74+ h_len ,
6675 )
67- fractional_coords_w = patched_float_arange (
68- torch .tensor (0.0 ), torch .tensor (1 - 1e-6 ), w_len
76+ torch ._check (w_len .item () > 0 )
77+ fractional_coords_w = torch .arange (
78+ torch .tensor (0.0 , dtype = p_attn_mask .dtype ),
79+ torch .tensor (1 - 1e-6 , dtype = p_attn_mask .dtype ),
80+ w_len ,
6981 )
7082
71- # torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[:, 0].sum().item())
72- # torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[0].sum().item())
73-
7483 bucket_coords_h = torch .bucketize (fractional_coords_h , boundaries , right = True )
7584 bucket_coords_w = torch .bucketize (fractional_coords_w , boundaries , right = True )
7685
@@ -116,17 +125,12 @@ def forward(self, patch_attention_mask, position_ids, boundaries):
116125 self .assertEqualArray (expected , got )
117126
118127 DYN = torch .export .Dim .DYNAMIC
119- ep = torch .export .export (model , inputs , dynamic_shapes = ({0 : DYN }, {0 : DYN }, {0 : DYN }))
120- try :
121- got = ep .module ()(* inputs )
122- except Exception :
123- # At least it exports, we need to remove the assert from the exported program.
124- # Let's revisit this later.
125- if has_torch ("2.11" ):
126- raise
127- got = None
128- if got is not None :
129- self .assertEqualArray (expected , got )
128+ with torch_export_patches (patch_torch = True ):
129+ ep = torch .export .export (
130+ model , inputs , dynamic_shapes = ({0 : DYN }, {0 : DYN }, {0 : DYN })
131+ )
132+ got = ep .module ()(* inputs )
133+ self .assertEqualArray (expected , got )
130134
131135
132136if __name__ == "__main__" :
0 commit comments