@@ -57,6 +57,7 @@ def test_qwen25_vli_visual(self):
5757 TESTDEVICE=cuda \\
5858 TESTDTYPE=float16 \\
5959 EXPORTER=custom \\
60+ CUT_EXPORTED_PROGRAM=qwen_sdpa_attention_loopmha_16 \\
6061 python _unittests/ut_tasks/try_export.py -k qwen25_vli_visual
6162
6263 .. code-block:: bash
@@ -78,6 +79,9 @@ def test_qwen25_vli_visual(self):
7879 "float32" : torch .float32 ,
7980 }[dtype ]
8081 exporter = os .environ .get ("EXPORTER" , "custom" )
82+ cut_ep = os .environ .get ("CUT_EXPORTED_PROGRAM" , None )
83+ if cut_ep is not None :
84+ cut_ep = cut_ep .split ("," )
8185
8286 from transformers import AutoModel , AutoProcessor
8387 from onnx_diagnostic .torch_export_patches .patches ._patch_transformers_qwen2_5 import (
@@ -135,15 +139,18 @@ def _config_reduction(config, task):
135139 grid_thw = torch .tensor ([[1 , 34 , 38 ]], dtype = torch .int64 ).to (device ),
136140 )
137141 if not self .unit_test_going ():
138- print ("-- save inputs" )
142+ print ("-- save big inputs" )
139143 torch .save (big_inputs , self .get_dump_file ("qwen25_vli_visual.inputs.big.pt" ))
140144 torch .save (inputs , self .get_dump_file ("qwen25_vli_visual.inputs.pt" ))
141145
142146 print (f"-- inputs: { self .string_type (inputs , with_shape = True )} " )
143147 # this is too long
144148 model_to_export = model .visual if hasattr (model , "visual" ) else model .model .visual
145149 begin = time .perf_counter ()
146- expected = model_to_export (** inputs )
150+ if not os .environ .get ("STOPAT" , "" ):
151+ expected = model_to_export (** inputs )
152+ else :
153+ expected = None
147154 print (f"-- MODEL RUN IN { time .perf_counter () - begin } " )
148155 print (f"-- expected: { self .string_type (expected , with_shape = True )} " )
149156
@@ -184,6 +191,8 @@ def _config_reduction(config, task):
184191 verbose = 1 ,
185192 stop_if_static = 2 ,
186193 ):
194+ if expected is None :
195+ expected = model_to_export (** inputs )
187196 to_onnx (
188197 model_to_export ,
189198 kwargs = export_inputs ,
@@ -195,6 +204,7 @@ def _config_reduction(config, task):
195204 target_opset = 24 if attention == "LOOPA24" else 22 ,
196205 optimize = True ,
197206 onnx_plugs = PLUGS ,
207+ cut_ep = cut_ep ,
198208 )
199209
200210 if not self .unit_test_going ():
0 commit comments