Skip to content

Commit 1cbe7f0

Browse files
committed
preliminiary support for submodule export
1 parent 0ee1179 commit 1cbe7f0

6 files changed

Lines changed: 91 additions & 11 deletions

File tree

_doc/cmds/validate.rst

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ of function :func:`onnx_diagnostic.torch_models.validate.run_ort_fusion`.
124124

125125
main("validate -m arnir0/Tiny-LLM --run -v 1 --export onnx-dynamo -o dump_models --patch --opt ir --ortfusiontype ALL".split())
126126

127-
Sdpa or Eager implementation or Use a StaticCache
127+
SDPA or Eager implementation or Use a StaticCache
128128
+++++++++++++++++++++++++++++++++++++++++++++++++
129129

130130
Add ``--mop cache_implementation=static --iop cls_cache=StaticCache`` to use a StaticCache instead of a DynamicCache (default).
@@ -147,3 +147,22 @@ Add ``--mop attn_implementation=eager`` to explicitly select eager implementatio
147147
--mop attn_implementation=eager \
148148
--mop cache_implementation=static \
149149
--iop cls_cache=StaticCache
150+
151+
Frequent examples used to test
152+
++++++++++++++++++++++++++++++
153+
154+
.. code-block:: bash
155+
156+
python -m onnx_diagnostic validate -m arnir0/Tiny-LLM --run -v 1 --device cuda --dtype float16 -o dump_models --patch --opt default+onnxruntime --export custom
157+
158+
About the exporter 'custom'
159+
+++++++++++++++++++++++++++
160+
161+
It used to investigate issues or scenarios. It is usually very strict
162+
and fails everytime it falls in one unexpected situation.
163+
It call :func:`experimental_experiment.torch_interpreter.to_onnx`.
164+
Some useful environment variables to set before running the command line.
165+
166+
* ``DROPPATTERN=<pattern1,patterns2,...>``: do not apply those patterns when optimizing a model
167+
* ``DUMPPATTERNS=<folder>``: dumps all matched and applied nodes when a pattern is applied
168+
* ``PATTERN=<pattern1,pattern2,...>``: increase verbosity for specific patterns to understand why one pattern was not applied

_unittests/ut_tasks/test_tasks.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,24 @@ def test_text_generation(self):
4747
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
4848
)
4949

50+
@hide_stdout()
51+
def test_submodule(self):
52+
mid = "arnir0/Tiny-LLM::model"
53+
data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
54+
self.assertEqual(data["task"], "text-generation")
55+
self.assertIn("inputs", data)
56+
self.assertIn("inputs2", data)
57+
self.assertIn("inputs_batch1", data)
58+
self.assertIn("inputs_empty_cache", data)
59+
self.assertIn((data["size"], data["n_weights"]), [(27379968, 6844992)])
60+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
61+
model(**inputs)
62+
model(**data["inputs2"])
63+
with torch_export_patches(patch_transformers=True, verbose=10):
64+
torch.export.export(
65+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
66+
)
67+
5068
@hide_stdout()
5169
def test_text_generation_empty_cache(self):
5270
mid = "arnir0/Tiny-LLM"

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -638,12 +638,14 @@ def forward(
638638
self.config._attn_implementation
639639
]
640640

641-
is_sdpa = (
641+
is_sdpa_or_eager = (
642642
attention_interface
643643
is transformers.integrations.sdpa_attention.sdpa_attention_forward
644644
or attention_interface is patched_sdpa_attention_forward
645+
or attention_interface
646+
is transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.eager_attention_forward
645647
)
646-
if is_sdpa:
648+
if is_sdpa_or_eager:
647649
attn_output = qwen_sdpa_attention_versatile(
648650
query_states,
649651
key_states,

onnx_diagnostic/torch_models/code_sample.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def code_sample(
236236
)
237237
)
238238
"""
239-
model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
239+
model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
240240
model_id,
241241
subfolder,
242242
same_as_pretrained=same_as_pretrained,
@@ -256,6 +256,7 @@ def code_sample(
256256
model_kwargs=mop,
257257
subfolder=subfolder,
258258
add_second_input=False,
259+
submodule=submodule,
259260
)
260261
if drop_inputs:
261262
update = {}

onnx_diagnostic/torch_models/hghub/model_inputs.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,26 @@ def _code_needing_rewriting(model: Any) -> Any:
2626

2727

2828
def _preprocess_model_id(
29-
model_id: str, subfolder: Optional[str], same_as_pretrained: bool, use_pretrained: bool
30-
) -> Tuple[str, Optional[str], bool, bool]:
29+
model_id: str,
30+
subfolder: Optional[str],
31+
same_as_pretrained: bool,
32+
use_pretrained: bool,
33+
submodule: Optional[str] = None,
34+
) -> Tuple[str, Optional[str], bool, bool, Optional[str]]:
35+
if "::" in model_id:
36+
assert (
37+
not submodule
38+
), f"submodule={submodule!r} cannot be defined in model_id={model_id!r} as well"
39+
model_id, submodule = model_id.split("::", maxsplit=1)
3140
if subfolder or "//" not in model_id:
32-
return model_id, subfolder, same_as_pretrained, use_pretrained
41+
return model_id, subfolder, same_as_pretrained, use_pretrained, submodule
3342
spl = model_id.split("//")
3443
if spl[-1] == "pretrained":
35-
return _preprocess_model_id("//".join(spl[:-1]), "", True, True)
44+
return _preprocess_model_id("//".join(spl[:-1]), "", True, True, submodule)
3645
if spl[-1] in {"transformer", "vae"}:
3746
# known subfolder
3847
return "//".join(spl[:-1]), spl[-1], same_as_pretrained, use_pretrained
39-
return model_id, subfolder, same_as_pretrained, use_pretrained
48+
return model_id, subfolder, same_as_pretrained, use_pretrained, submodule
4049

4150

4251
def get_untrained_model_with_inputs(
@@ -54,6 +63,7 @@ def get_untrained_model_with_inputs(
5463
subfolder: Optional[str] = None,
5564
use_only_preinstalled: bool = False,
5665
config_reduction: Optional[Callable[[Any, str], Dict]] = None,
66+
submodule: Optional[str] = None,
5767
) -> Dict[str, Any]:
5868
"""
5969
Gets a non initialized model similar to the original model
@@ -82,6 +92,7 @@ def get_untrained_model_with_inputs(
8292
<onnx_diagnostic.torch_models.hghub.reduce_model_config>`,
8393
this function takes a configuration and a task (string)
8494
as arguments
95+
:param submodule: use a submodule instead of the main model
8596
:return: dictionary with a model, inputs, dynamic shapes, and the configuration,
8697
some necessary rewriting as well
8798
@@ -108,11 +119,12 @@ def get_untrained_model_with_inputs(
108119
f"model_id={model_id!r}, preinstalled model is only available "
109120
f"if use_only_preinstalled is False."
110121
)
111-
model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
122+
model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
112123
model_id,
113124
subfolder,
114125
same_as_pretrained=same_as_pretrained,
115126
use_pretrained=use_pretrained,
127+
submodule=submodule,
116128
)
117129
if verbose:
118130
print(
@@ -147,6 +159,8 @@ def get_untrained_model_with_inputs(
147159
if verbose:
148160
print(f"[get_untrained_model_with_inputs] architecture={arch!r}")
149161
print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
162+
if submodule:
163+
print(f"[get_untrained_model_with_inputs] submodule={submodule!r}")
150164
if task is None:
151165
task = task_from_arch(arch, model_id=model_id, subfolder=subfolder)
152166
if verbose:
@@ -357,6 +371,19 @@ def get_untrained_model_with_inputs(
357371
if diff_config is not None:
358372
res["dump_info"] = dict(config_diff=diff_config)
359373

374+
if submodule:
375+
path = submodule.split("::") if "::" in submodule else [submodule]
376+
for p in path:
377+
assert hasattr(model, p), (
378+
f"Unable to find submodule {p!r} in in class {type(model)}, "
379+
f"submodule={submodule!r}, possible candidates: "
380+
f"{[k for k in dir(model) if isinstance(getattr(model, k), torch.nn.Module)]}"
381+
)
382+
model = getattr(model, p)
383+
384+
if verbose:
385+
print(f"[get_untrained_model_with_inputs] model class={model.__class__.__name__!r}")
386+
360387
sizes = compute_model_size(model)
361388
res["model"] = model
362389
res["configuration"] = config

onnx_diagnostic/torch_models/validate.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,13 +349,15 @@ def _prepare_validation(
349349
verbose,
350350
output_names,
351351
dump_folder,
352+
submodule,
352353
):
353354
main_validation_begin = time.perf_counter()
354-
model_id, subfolder, same_as_pretrained, use_pretrained = _preprocess_model_id(
355+
model_id, subfolder, same_as_pretrained, use_pretrained, submodule = _preprocess_model_id(
355356
model_id,
356357
subfolder,
357358
same_as_pretrained=same_as_pretrained,
358359
use_pretrained=use_pretrained,
360+
submodule=submodule,
359361
)
360362
time_preprocess_model_id = time.perf_counter() - main_validation_begin
361363
patch_kwargs = make_patch_kwargs(patch=patch, rewrite=rewrite)
@@ -364,6 +366,7 @@ def _prepare_validation(
364366
summary.update(
365367
dict(
366368
version_model_id=model_id,
369+
version_submodule=submodule,
367370
version_do_run=str(do_run),
368371
version_dtype=str(dtype or ""),
369372
version_device=str(device or ""),
@@ -444,6 +447,7 @@ def _prepare_validation(
444447
dump_folder,
445448
folder_name,
446449
patch_kwargs,
450+
submodule,
447451
)
448452

449453

@@ -460,6 +464,7 @@ def _get_untrained_model_with_inputs(
460464
inputs2,
461465
quiet,
462466
dump_folder,
467+
submodule,
463468
):
464469
iop = input_options or {}
465470
mop = model_options or {}
@@ -480,6 +485,7 @@ def _get_untrained_model_with_inputs(
480485
model_kwargs=mop,
481486
subfolder=sub,
482487
add_second_input=i2,
488+
submodule=submodule,
483489
)
484490
)
485491
),
@@ -842,6 +848,7 @@ def validate_model(
842848
ort_logs: bool = False,
843849
quiet_input_sets: Optional[Set[str]] = None,
844850
save_ep: Optional[str] = None,
851+
submodule: Optional[str] = None,
845852
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
846853
"""
847854
Validates a model.
@@ -902,6 +909,7 @@ def validate_model(
902909
even if quiet is False
903910
:param save_ep: if not empty, this can be used to save the input sets and
904911
the exported program
912+
:param submodule: to test not the model but a submodule of this model
905913
:return: two dictionaries, one with some metrics,
906914
another one with whatever the function produces
907915
@@ -966,6 +974,7 @@ def validate_model(
966974
use_pretrained=use_pretrained,
967975
same_as_pretrained=same_as_pretrained,
968976
save_ep=save_ep,
977+
submodule=submodule,
969978
)
970979
if dump_folder:
971980
with open(dump_stats, "w") as f:
@@ -1053,6 +1062,7 @@ def _validate_model_step1(
10531062
use_pretrained,
10541063
same_as_pretrained,
10551064
save_ep,
1065+
submodule,
10561066
):
10571067
assert not do_same or do_run, (
10581068
f"Discrepancies cannot be measured if the model is not run, "
@@ -1067,6 +1077,7 @@ def _validate_model_step1(
10671077
dump_folder,
10681078
folder_name,
10691079
patch_kwargs,
1080+
submodule,
10701081
) = _prepare_validation(
10711082
model_id=model_id,
10721083
subfolder=subfolder,
@@ -1093,6 +1104,7 @@ def _validate_model_step1(
10931104
verbose=verbose,
10941105
output_names=output_names,
10951106
dump_folder=dump_folder,
1107+
submodule=submodule,
10961108
)
10971109

10981110
data, iop, mop = _get_untrained_model_with_inputs(
@@ -1108,6 +1120,7 @@ def _validate_model_step1(
11081120
inputs2=inputs2,
11091121
quiet=quiet,
11101122
dump_folder=dump_folder,
1123+
submodule=submodule,
11111124
)
11121125

11131126
second_input_keys = [k for k in data if k.startswith("inputs") and k != "inputs"]

0 commit comments

Comments
 (0)