Skip to content

Commit 8bae9ab

Browse files
committed
fix
1 parent 57270e6 commit 8bae9ab

4 files changed

Lines changed: 31 additions & 3 deletions

File tree

_unittests/ut_tasks/try_export.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def test_qwen25_vli_visual(self):
5757
TESTDEVICE=cuda \\
5858
TESTDTYPE=float16 \\
5959
EXPORTER=custom \\
60+
CUT_EXPORTED_PROGRAM=qwen_sdpa_attention_loopmha_16 \\
6061
python _unittests/ut_tasks/try_export.py -k qwen25_vli_visual
6162
6263
.. code-block:: bash
@@ -78,6 +79,9 @@ def test_qwen25_vli_visual(self):
7879
"float32": torch.float32,
7980
}[dtype]
8081
exporter = os.environ.get("EXPORTER", "custom")
82+
cut_ep = os.environ.get("CUT_EXPORTED_PROGRAM", None)
83+
if cut_ep is not None:
84+
cut_ep = cut_ep.split(",")
8185

8286
from transformers import AutoModel, AutoProcessor
8387
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
@@ -135,15 +139,18 @@ def _config_reduction(config, task):
135139
grid_thw=torch.tensor([[1, 34, 38]], dtype=torch.int64).to(device),
136140
)
137141
if not self.unit_test_going():
138-
print("-- save inputs")
142+
print("-- save big inputs")
139143
torch.save(big_inputs, self.get_dump_file("qwen25_vli_visual.inputs.big.pt"))
140144
torch.save(inputs, self.get_dump_file("qwen25_vli_visual.inputs.pt"))
141145

142146
print(f"-- inputs: {self.string_type(inputs, with_shape=True)}")
143147
# this is too long
144148
model_to_export = model.visual if hasattr(model, "visual") else model.model.visual
145149
begin = time.perf_counter()
146-
expected = model_to_export(**inputs)
150+
if not os.environ.get("STOPAT", ""):
151+
expected = model_to_export(**inputs)
152+
else:
153+
expected = None
147154
print(f"-- MODEL RUN IN {time.perf_counter() - begin}")
148155
print(f"-- expected: {self.string_type(expected, with_shape=True)}")
149156

@@ -184,6 +191,8 @@ def _config_reduction(config, task):
184191
verbose=1,
185192
stop_if_static=2,
186193
):
194+
if expected is None:
195+
expected = model_to_export(**inputs)
187196
to_onnx(
188197
model_to_export,
189198
kwargs=export_inputs,
@@ -195,6 +204,7 @@ def _config_reduction(config, task):
195204
target_opset=24 if attention == "LOOPA24" else 22,
196205
optimize=True,
197206
onnx_plugs=PLUGS,
207+
cut_ep=cut_ep,
198208
)
199209

200210
if not self.unit_test_going():

onnx_diagnostic/_command_lines_parser.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,6 +1437,10 @@ def _size(name):
14371437
print("-- done")
14381438
del sess
14391439

1440+
if not args.sbs:
1441+
print("-- done")
1442+
return
1443+
14401444
print(f"-- load onnx {args.onnx!r}")
14411445
begin = time.perf_counter()
14421446
onx = onnx.load(args.onnx)

onnx_diagnostic/export/api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def to_onnx(
2121
use_control_flow_dispatcher: bool = False,
2222
onnx_plugs: Optional[List[EagerDirectReplacementWithOnnx]] = None,
2323
inline: bool = True,
24+
cut_ep: Optional[List[str]] = None,
2425
) -> Any:
2526
"""
2627
Common API for exporters. By default, the models are optimized to use the
@@ -46,6 +47,8 @@ def to_onnx(
4647
custom loops (see :func:`onnx_diagnostic.export.control_flow_onnx.loop_for_onnx`)
4748
:param onnx_plugs: the code was modified to replace some parts with onnx translation
4849
:param inline: inline local functions
50+
:param cut_ep: if specified, cuts the exported program before exporting;
51+
this is a debugging aid used to investigate issues.
4952
:return: the output of the selected exporter, usually a structure including
5053
an onnx model
5154
@@ -140,7 +143,7 @@ def find_method(self, name: Any):
140143
dynamic_shapes=dynamic_shapes,
141144
large_model=True,
142145
output_dynamic_shapes=output_dynamic_shapes,
143-
export_options=ExportOptions(save_ep=save_ep),
146+
export_options=ExportOptions(save_ep=save_ep, cut_ep=cut_ep),
144147
options=options,
145148
inline=inline,
146149
dispatcher=main_dispatcher,
@@ -155,6 +158,7 @@ def find_method(self, name: Any):
155158
assert (
156159
not output_dynamic_shapes
157160
), f"output_dynamic_shapes not supported for exporter={exporter!r}"
161+
assert not cut_ep, f"cut_ep={cut_ep} not available with exporter={exporter!r}"
158162
custom_translation_table = {}
159163
if onnx_plugs:
160164
for plug in onnx_plugs:
@@ -218,6 +222,7 @@ def find_method(self, name: Any):
218222
f"Only a specified set of inputs is supported for exporter={exporter!r}, "
219223
f"but it is {list(kwargs)}" # type: ignore[arg-type]
220224
)
225+
assert not cut_ep, f"cut_ep={cut_ep} not available with exporter={exporter!r}"
221226
flat_inputs = flatten_object(kwargs, drop_keys=True)
222227
first = flat_inputs[0]
223228
first_float = [

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@
2626
op = onnxscript.opset22
2727
op24 = onnxscript.onnx_opset.opset24
2828
msft_op = onnxscript.values.Opset("com.microsoft", 1)
29+
STOPAT = (
30+
int(os.environ.get("STOPAT", None))
31+
if os.environ.get("STOPAT", None) is not None
32+
else None
33+
)
2934

3035
def _add_com_microsoft_opset(function_proto: onnx.FunctionProto) -> onnx.FunctionProto:
3136
opsets = {d.domain: d.version for d in function_proto.opset_import}
@@ -529,8 +534,12 @@ def forward(
529534
position_embeddings=position_embeddings,
530535
**kwargs,
531536
)
537+
if STOPAT is not None and layer_num > STOPAT:
538+
break
532539

533540
hidden_states = self.merger(hidden_states)
541+
if STOPAT is not None:
542+
return hidden_states
534543
reverse_indices = torch.argsort(window_index)
535544
hidden_states = hidden_states[reverse_indices, :]
536545
return hidden_states

0 commit comments

Comments
 (0)