@@ -161,18 +161,24 @@ def patched__get_range_constraints(
161161
162162 combined_args = torch .export ._trace ._combine_args (mod , args , kwargs )
163163
164- # _combine_args does not preserve the order.
164+ # This is because we trace based on the kwargs passed in from user
165+ # not based on the signature. I feel it would be better to just enforce
166+ # one ordering at the start of tracing to avoid confusions, but that is
167+ # bigger refactor, so do this to unblock for now.
165168 assert isinstance (
166169 combined_args , dict
167170 ), f"unexpected type { type (combined_args )} for 'combined_args'"
168- new_args = {}
169- for k in kwargs :
170- if k in combined_args :
171- new_args [k ] = combined_args [k ]
172- for k in combined_args :
173- if k not in new_args :
174- new_args [k ] = combined_args [k ]
175- combined_args = new_args
171+
172+ combined_args_traced_order = {}
173+ for arg in kwargs :
174+ if arg in combined_args :
175+ combined_args_traced_order [arg ] = combined_args [arg ]
176+
177+ for key in combined_args :
178+ if key not in combined_args_traced_order :
179+ combined_args_traced_order [key ] = combined_args [key ]
180+
181+ combined_args = combined_args_traced_order
176182
177183 range_constraints = torch ._export .non_strict_utils .make_constraints (
178184 fake_mode , gm , combined_args , dynamic_shapes , num_lifted
0 commit comments