Skip to content

Commit 447a572

Browse files
author
Han Wang
committed
fix(pt_expt): tune inductor options for compile training
- max_fusion_size 8 → 64 to avoid scheduler timeouts on large descriptors
- add triton.mix_order_reduction=False for PyTorch <=2.11 bugs (pytorch/pytorch#174379, #178080, #179494)
- hardcode defaults, let user compile_options override per-key
1 parent f834202 commit 447a572

1 file changed

Lines changed: 20 additions & 14 deletions

File tree

deepmd/pt_expt/train/training.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def _trace_and_compile(
189189
mapping: torch.Tensor,
190190
fparam: torch.Tensor | None,
191191
aparam: torch.Tensor | None,
192-
compile_opts: dict[str, Any],
192+
compile_opts: dict[str, Any] | None = None,
193193
) -> torch.nn.Module:
194194
"""Symbolic-trace ``forward_lower`` and compile with inductor + dynamic=True.
195195
@@ -199,9 +199,9 @@ def _trace_and_compile(
199199
The (uncompiled) model.
200200
ext_coord, ext_atype, nlist, mapping, fparam, aparam
201201
Sample tensors used to seed the symbolic tracer.
202-
compile_opts : dict
203-
Options forwarded to ``torch.compile`` (the ``dynamic`` and
204-
``backend`` keys are ignored and replaced).
202+
compile_opts : dict or None
203+
User-supplied inductor options. These are merged on top of the
204+
built-in defaults (user values take precedence).
205205
206206
Returns
207207
-------
@@ -296,22 +296,28 @@ def _expand(t: torch.Tensor | None) -> torch.Tensor | None:
296296
if not was_training:
297297
model.eval()
298298

299-
# Work on a copy; ignore caller-supplied dynamic/backend.
300-
compile_opts = {
301-
k: v for k, v in compile_opts.items() if k not in ("dynamic", "backend")
299+
# Inductor defaults tuned for second-order-gradient training graphs.
300+
# User-supplied compile_opts override these on a per-key basis.
301+
inductor_options: dict[str, Any] = {
302+
"max_autotune": False,
303+
"shape_padding": True,
304+
"epilogue_fusion": False,
305+
"triton.cudagraphs": False,
306+
"max_fusion_size": 64,
307+
# NOTE: mix_order_reduction hits multiple bugs under
308+
# data-dependent symbolic shapes on PyTorch <=2.11
309+
# (pytorch/pytorch#174379, #178080, #179494) -- our
310+
# edge count is exactly that kind of shape.
311+
"triton.mix_order_reduction": False,
302312
}
303-
opts = compile_opts.setdefault("options", {})
304-
opts.setdefault("max_autotune", False)
305-
opts.setdefault("epilogue_fusion", False)
306-
opts.setdefault("triton.cudagraphs", False)
307-
opts.setdefault("shape_padding", True)
308-
opts.setdefault("max_fusion_size", 8)
313+
if compile_opts:
314+
inductor_options.update(compile_opts)
309315

310316
return torch.compile(
311317
traced_lower,
312318
backend="inductor",
313319
dynamic=True,
314-
**compile_opts,
320+
options=inductor_options,
315321
)
316322

317323

0 commit comments

Comments
 (0)