@@ -189,7 +189,7 @@ def _trace_and_compile(
189189 mapping : torch .Tensor ,
190190 fparam : torch .Tensor | None ,
191191 aparam : torch .Tensor | None ,
192- compile_opts : dict [str , Any ],
192+ compile_opts : dict [str , Any ] | None = None ,
193193) -> torch .nn .Module :
194194 """Symbolic-trace ``forward_lower`` and compile with inductor + dynamic=True.
195195
@@ -199,9 +199,9 @@ def _trace_and_compile(
199199 The (uncompiled) model.
200200 ext_coord, ext_atype, nlist, mapping, fparam, aparam
201201 Sample tensors used to seed the symbolic tracer.
202- compile_opts : dict
203- Options forwarded to ``torch.compile`` (the ``dynamic`` and
204- ``backend`` keys are ignored and replaced ).
202+ compile_opts : dict or None
203+ User-supplied inductor options. These are merged on top of the
204+ built-in defaults (user values take precedence ).
205205
206206 Returns
207207 -------
@@ -296,22 +296,28 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None:
296296 if not was_training :
297297 model .eval ()
298298
299- # Work on a copy; ignore caller-supplied dynamic/backend.
300- compile_opts = {
301- k : v for k , v in compile_opts .items () if k not in ("dynamic" , "backend" )
299+ # Inductor defaults tuned for second-order-gradient training graphs.
300+ # User-supplied compile_opts override these on a per-key basis.
301+ inductor_options : dict [str , Any ] = {
302+ "max_autotune" : False ,
303+ "shape_padding" : True ,
304+ "epilogue_fusion" : False ,
305+ "triton.cudagraphs" : False ,
306+ "max_fusion_size" : 64 ,
307+ # NOTE: mix_order_reduction hits multiple bugs under
308+ # data-dependent symbolic shapes on PyTorch <=2.11
309+ # (pytorch/pytorch#174379, #178080, #179494) -- our
310+ # edge count is exactly that kind of shape.
311+ "triton.mix_order_reduction" : False ,
302312 }
303- opts = compile_opts .setdefault ("options" , {})
304- opts .setdefault ("max_autotune" , False )
305- opts .setdefault ("epilogue_fusion" , False )
306- opts .setdefault ("triton.cudagraphs" , False )
307- opts .setdefault ("shape_padding" , True )
308- opts .setdefault ("max_fusion_size" , 8 )
313+ if compile_opts :
314+ inductor_options .update (compile_opts )
309315
310316 return torch .compile (
311317 traced_lower ,
312318 backend = "inductor" ,
313319 dynamic = True ,
314- ** compile_opts ,
320+ options = inductor_options ,
315321 )
316322
317323
0 commit comments