Skip to content

Commit da9cb89

Browse files
committed
support custom-tracing
1 parent 39977bd commit da9cb89

1 file changed

Lines changed: 9 additions & 0 deletions

File tree

onnx_diagnostic/torch_models/validate.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2330,6 +2330,7 @@ def call_torch_export_custom(
23302330
"custom-dec",
23312331
"custom-decall",
23322332
"custom-fake",
2333+
"custom-tracing",
23332334
}
23342335
assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
23352336
assert "model" in data, f"model is missing from data: {sorted(data)}"
@@ -2342,11 +2343,16 @@ def call_torch_export_custom(
23422343
f"Options strict cannot be specified in the exporter name {exporter!r} "
23432344
f"and in the options {exporter_options}"
23442345
)
2346+
assert ("-tracing" not in exporter) or ("tracing" not in exporter_options), (
2347+
f"Options tracing cannot be specified in the exporter name {exporter!r} "
2348+
f"and in the options {exporter_options}"
2349+
)
23452350
summary: Dict[str, Union[str, int, float]] = {}
23462351
strict = "-strict" in exporter or exporter_options.pop("strict", False)
23472352
args, kwargs = split_args_kwargs(data["inputs_export"])
23482353
ds = data.get("dynamic_shapes", None)
23492354
fake = "-fake" in exporter or exporter_options.pop("fake", False)
2355+
tracing = "-tracing" in exporter or exporter_options.pop("tracing", False)
23502356
if fake:
23512357
from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
23522358

@@ -2370,6 +2376,7 @@ def call_torch_export_custom(
23702376
summary["export_exporter"] = exporter
23712377
summary["export_optimization"] = optimization or ""
23722378
summary["export_strict"] = strict
2379+
summary["export_tracing"] = tracing
23732380
summary["export_fake"] = fake
23742381
summary["export_args"] = string_type(args, with_shape=True)
23752382
summary["export_kwargs"] = string_type(kwargs, with_shape=True)
@@ -2392,6 +2399,7 @@ def call_torch_export_custom(
23922399
)
23932400
)
23942401
large_model = bool(exporter_options.pop("large_model", True))
2402+
exporter_options.pop("tracing", False)
23952403
return_optimize_report = bool(exporter_options.pop("return_optimize_report", True))
23962404
export_modules_as_functions = bool(
23972405
exporter_options.pop("export_modules_as_functions", False)
@@ -2405,6 +2413,7 @@ def call_torch_export_custom(
24052413
summary["export_external_threshold"] = str(external_threshold)
24062414

24072415
export_options = ExportOptions(
2416+
tracing=tracing,
24082417
strict=strict,
24092418
decomposition_table=decomposition_table,
24102419
save_ep=(

0 commit comments

Comments
 (0)