Skip to content

Commit 5906a90

Browse files
committed
fix again
1 parent 02eb10c commit 5906a90

1 file changed

Lines changed: 15 additions & 9 deletions

File tree

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -161,18 +161,24 @@ def patched__get_range_constraints(
161161

162162
combined_args = torch.export._trace._combine_args(mod, args, kwargs)
163163

164-
# _combine_args does not preserve the order.
164+
# This is because we trace based on the kwargs passed in from user
165+
# not based on the signature. I feel it would be better to just enforce
166+
# one ordering at the start of tracing to avoid confusions, but that is
167+
# bigger refactor, so do this to unblock for now.
165168
assert isinstance(
166169
combined_args, dict
167170
), f"unexpected type {type(combined_args)} for 'combined_args'"
168-
new_args = {}
169-
for k in kwargs:
170-
if k in combined_args:
171-
new_args[k] = combined_args[k]
172-
for k in combined_args:
173-
if k not in new_args:
174-
new_args[k] = combined_args[k]
175-
combined_args = new_args
171+
172+
combined_args_traced_order = {}
173+
for arg in kwargs:
174+
if arg in combined_args:
175+
combined_args_traced_order[arg] = combined_args[arg]
176+
177+
for key in combined_args:
178+
if key not in combined_args_traced_order:
179+
combined_args_traced_order[key] = combined_args[key]
180+
181+
combined_args = combined_args_traced_order
176182

177183
range_constraints = torch._export.non_strict_utils.make_constraints(
178184
fake_mode, gm, combined_args, dynamic_shapes, num_lifted

0 commit comments

Comments
 (0)