Skip to content

Commit 02eb10c

Browse files
committed
fix
1 parent 5035092 commit 02eb10c

2 files changed

Lines changed: 3 additions & 6 deletions

File tree

onnx_diagnostic/investigate/input_observer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,8 @@ def _post_process_for_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
677677
# Nothing to do here.
678678
return kwargs
679679
to_be_moved = {k for k in kwargs if k not in self.signature_names}
680+
if not to_be_moved:
681+
return kwargs
680682
keywords = {k: v for k, v in kwargs.items() if k in to_be_moved}
681683
new_kwargs = {k: v for k, v in kwargs.items() if k not in to_be_moved}
682684
return {**new_kwargs, self.kwargs_name: keywords}

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,8 @@ def patched__get_range_constraints(
165165
assert isinstance(
166166
combined_args, dict
167167
), f"unexpected type {type(combined_args)} for 'combined_args'"
168-
input_names = [
169-
s.arg.name
170-
for s in export_graph_signature.input_specs
171-
if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
172-
]
173168
new_args = {}
174-
for k in input_names:
169+
for k in kwargs:
175170
if k in combined_args:
176171
new_args[k] = combined_args[k]
177172
for k in combined_args:

0 commit comments

Comments
 (0)