Commit a183f95 (parent: baab3e8)

Author: Han Wang

feat: dynamic compile with inductor for training

2 files changed: 100 additions & 221 deletions

File: deepmd/pt_expt/train/training.py (83 additions & 187 deletions)
@@ -153,14 +153,21 @@ def _trace_and_compile(
 ) -> torch.nn.Module:
     """Trace ``forward_lower`` with ``make_fx`` and compile with ``torch.compile``.
 
+    Uses symbolic tracing (``tracing_mode="symbolic"``) so the resulting
+    FX graph captures shape-polymorphic operations. The graph is then
+    compiled with ``torch.compile(dynamic=True)`` and the inductor
+    backend, which automatically pads tensor shapes for efficient kernel
+    execution (``shape_padding=True``).
+
     Parameters
     ----------
     model : torch.nn.Module
-        The (uncompiled) model. Temporarily set to eval mode for tracing.
+        The (uncompiled) model.
     ext_coord, ext_atype, nlist, mapping, fparam, aparam
-        Sample tensors (already padded to the desired max_nall).
+        Sample tensors used to drive the symbolic trace.
     compile_opts : dict
-        Options forwarded to ``torch.compile`` (excluding ``dynamic``).
+        Options forwarded to ``torch.compile``. Keys ``dynamic`` and
+        ``backend`` are set internally and ignored if provided.
 
     Returns
     -------
@@ -197,84 +204,52 @@ def fn(
             aparam=aparam,
         )
 
-    # Use default tracing_mode="real" (concrete shapes) for best
-    # runtime performance. If data-dependent intermediate shapes
-    # change at runtime, the caller catches the error and retraces.
-    traced_lower = make_fx(fn)(ext_coord, ext_atype, nlist, mapping, fparam, aparam)
+    # Symbolic tracing captures shape-polymorphic ops, pairing with
+    # dynamic=True in torch.compile to handle varying nall without
+    # manual padding or recompilation.
+    traced_lower = make_fx(
+        fn,
+        tracing_mode="symbolic",
+        _allow_non_fake_inputs=True,
+    )(ext_coord, ext_atype, nlist, mapping, fparam, aparam)
 
     if not was_training:
         model.eval()
 
-    # The inductor backend does not propagate gradients through the
-    # make_fx-decomposed autograd.grad ops (second-order gradients for
-    # force training). Use "aot_eager" which correctly preserves the
-    # gradient chain while still benefiting from make_fx decomposition.
-    if "backend" not in compile_opts:
-        compile_opts["backend"] = "aot_eager"
-    compiled_lower = torch.compile(traced_lower, dynamic=False, **compile_opts)
+    # Override backend and dynamic — the inductor backend with
+    # dynamic=True handles varying shapes automatically.
+    compile_opts.pop("dynamic", None)
+    compile_opts.pop("backend", None)
+    if "options" not in compile_opts:
+        compile_opts["options"] = {}
+    compile_opts["options"].setdefault("shape_padding", True)
+
+    compiled_lower = torch.compile(
+        traced_lower,
+        backend="inductor",
+        dynamic=True,
+        **compile_opts,
+    )
     return compiled_lower
 
 
 class _CompiledModel(torch.nn.Module):
-    """Coord extension (eager) -> pad nall -> compiled forward_lower.
+    """Coord extension (eager) -> compiled forward_lower.
 
-    If a batch's ``nall`` exceeds the current ``max_nall``, the model is
-    automatically re-traced and recompiled with a larger pad size.
+    Coord extension and neighbor list construction involve data-dependent
+    control flow and are kept in eager mode. The compiled ``forward_lower``
+    handles varying ``nall`` via ``dynamic=True`` — no manual padding or
+    recompilation needed.
     """
 
     def __init__(
         self,
         original_model: torch.nn.Module,
         compiled_forward_lower: torch.nn.Module,
-        max_nall: int,
-        compile_opts: dict[str, Any],
     ) -> None:
         super().__init__()
         self.original_model = original_model
         self.compiled_forward_lower = compiled_forward_lower
-        self._max_nall = max_nall
-        self._compile_opts = compile_opts
-
-    def _recompile(
-        self,
-        ext_coord: torch.Tensor,
-        ext_atype: torch.Tensor,
-        nlist: torch.Tensor,
-        mapping: torch.Tensor,
-        fparam: torch.Tensor | None,
-        aparam: torch.Tensor | None,
-        new_max_nall: int,
-    ) -> None:
-        """Re-trace and recompile for the given inputs.
-
-        If *new_max_nall* differs from the current ``_max_nall``, the
-        inputs are padded (or already padded by the caller).
-        """
-        # Pad if the caller provides unpadded tensors (nall growth case)
-        actual_nall = ext_coord.shape[1]
-        pad_n = new_max_nall - actual_nall
-        if pad_n > 0:
-            ext_coord = torch.nn.functional.pad(ext_coord, (0, 0, 0, pad_n))
-            ext_atype = torch.nn.functional.pad(ext_atype, (0, pad_n))
-            mapping = torch.nn.functional.pad(mapping, (0, pad_n))
-
-        ext_coord = ext_coord.detach()
-
-        self.compiled_forward_lower = _trace_and_compile(
-            self.original_model,
-            ext_coord,
-            ext_atype,
-            nlist,
-            mapping,
-            fparam,
-            aparam,
-            self._compile_opts,
-        )
-        self._max_nall = new_max_nall
-        log.info(
-            "Recompiled model with max_nall=%d.",
-            new_max_nall,
-        )
 
     def forward(
         self,
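Note on the hunk above: the make_fx/torch.compile pairing can be exercised in
isolation. Below is a minimal sketch assuming only stock PyTorch 2.x; the toy
energy function and shapes are invented, and how cleanly autograd.grad traces
through make_fx can vary by torch version.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def energy(coord: torch.Tensor) -> torch.Tensor:
    # stand-in for the model's forward_lower energy head
    return (coord ** 2).sum()

def fn(coord: torch.Tensor) -> torch.Tensor:
    # the "double backward" pattern: forces via autograd.grad with
    # create_graph=True, which torch.compile cannot capture directly
    coord = coord.detach().requires_grad_(True)
    (force,) = torch.autograd.grad(energy(coord), coord, create_graph=True)
    return -force

# symbolic tracing decomposes autograd.grad into primitive ops with
# symbolic shapes, so a single graph serves any atom count
traced = make_fx(fn, tracing_mode="symbolic")(torch.randn(8, 3))
compiled = torch.compile(traced, backend="inductor", dynamic=True)

print(compiled(torch.randn(8, 3)).shape)   # torch.Size([8, 3])
print(compiled(torch.randn(20, 3)).shape)  # larger input, no retrace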
@@ -318,27 +293,6 @@ def forward(
             distinguish_types=False,
         )
         ext_coord = ext_coord.reshape(nframes, -1, 3)
-
-        # Grow max_nall if needed (retrace + recompile)
-        actual_nall = ext_coord.shape[1]
-        if actual_nall > self._max_nall:
-            new_max_nall = ((int(actual_nall * 1.2) + 7) // 8) * 8
-            log.info(
-                "nall=%d exceeds max_nall=%d; recompiling with max_nall=%d.",
-                actual_nall,
-                self._max_nall,
-                new_max_nall,
-            )
-            self._recompile(
-                ext_coord, ext_atype, nlist, mapping, fparam, aparam, new_max_nall
-            )
-
-        # Pad to max_nall so compiled graph sees a fixed shape
-        pad_n = self._max_nall - actual_nall
-        if pad_n > 0:
-            ext_coord = torch.nn.functional.pad(ext_coord, (0, 0, 0, pad_n))
-            ext_atype = torch.nn.functional.pad(ext_atype, (0, pad_n))
-            mapping = torch.nn.functional.pad(mapping, (0, pad_n))
         ext_coord = ext_coord.detach().requires_grad_(True)
 
         result = self.compiled_forward_lower(
@@ -350,22 +304,18 @@
         # Ghost-atom forces must be scatter-summed back to local atoms
         # via ``mapping`` — the same operation ``communicate_extended_output``
         # performs in the uncompiled path.
+        actual_nall = ext_coord.shape[1]
         out: dict[str, torch.Tensor] = {}
         out["atom_energy"] = result["atom_energy"]
         out["energy"] = result["energy"]
         if "extended_force" in result:
-            ext_force = result["extended_force"]  # (nf, nall_padded, 3)
-            # mapping may be padded; only use actual_nall entries
-            map_actual = mapping[:, :actual_nall]  # (nf, actual_nall)
-            ext_force_actual = ext_force[:, :actual_nall, :]  # (nf, actual_nall, 3)
+            ext_force = result["extended_force"]  # (nf, nall, 3)
             # scatter-sum extended forces onto local atoms
-            idx = map_actual.unsqueeze(-1).expand_as(
-                ext_force_actual
-            )  # (nf, actual_nall, 3)
+            idx = mapping.unsqueeze(-1).expand_as(ext_force)  # (nf, nall, 3)
             force = torch.zeros(
                 nframes, nloc, 3, dtype=ext_force.dtype, device=ext_force.device
             )
-            force.scatter_add_(1, idx, ext_force_actual)
+            force.scatter_add_(1, idx, ext_force)
             out["force"] = force
         if "virial" in result:
             out["virial"] = result["virial"]
@@ -642,21 +592,19 @@ def get_sample() -> list[dict[str, np.ndarray]]:
     def _compile_model(self, compile_opts: dict[str, Any]) -> None:
         """Replace ``self.model`` with a compiled version.
 
-        The model's ``forward`` uses ``torch.autograd.grad`` (for force
-        computation) with ``create_graph=True``, which creates a "double
-        backward" that ``torch.compile`` cannot handle.
-
-        Solution: use ``make_fx`` to trace ``forward_lower``, decomposing
-        ``torch.autograd.grad`` into primitive ops. The coord extension +
-        nlist build (data-dependent control flow) are kept outside the
-        compiled region.
-
-        To avoid the overhead of symbolic tracing and dynamic shapes, the
-        extended-atom dimension (nall) is padded to a fixed maximum
-        estimated from the training data. This allows concrete-shape
-        tracing and ``dynamic=False``. If a batch exceeds the current
-        max_nall at runtime, the model is automatically re-traced and
-        recompiled with a larger pad size.
+        The model's ``forward`` uses ``torch.autograd.grad`` (for forces)
+        with ``create_graph=True``, which creates a "double backward" that
+        ``torch.compile`` cannot handle.
+
+        Solution: use ``make_fx`` with ``tracing_mode="symbolic"`` to trace
+        ``forward_lower``, decomposing ``torch.autograd.grad`` into
+        primitive ops with symbolic shapes. The traced graph is compiled
+        with ``torch.compile(dynamic=True, backend="inductor")`` so
+        varying ``nall`` across batches is handled automatically — no
+        manual padding or recompilation needed.
+
+        Coord extension + nlist build (data-dependent control flow) are
+        kept outside the compiled region.
         """
         from deepmd.dpmodel.utils.nlist import (
             build_neighbor_list,
@@ -668,105 +616,53 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None:
 
         model = self.model
 
-        # --- Estimate max_nall by sampling multiple batches ---
-        n_sample = 20
-        max_nall = 0
-        best_sample: (
-            tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, dict] | None
-        ) = None
-
-        for _ii in range(n_sample):
-            inp, _ = self.get_data(is_train=True)
-            coord = inp["coord"].detach()
-            atype = inp["atype"].detach()
-            box = inp.get("box")
-            if box is not None:
-                box = box.detach()
-
-            nframes, nloc = atype.shape[:2]
-            coord_np = coord.cpu().numpy().reshape(nframes, nloc, 3)
-            atype_np = atype.cpu().numpy()
-            box_np = box.cpu().numpy().reshape(nframes, 9) if box is not None else None
-
-            if box_np is not None:
-                coord_norm = normalize_coord(coord_np, box_np.reshape(nframes, 3, 3))
-            else:
-                coord_norm = coord_np
+        # --- Get one sample batch to drive the symbolic trace ---
+        inp, _ = self.get_data(is_train=True)
+        coord = inp["coord"].detach()
+        atype = inp["atype"].detach()
+        box = inp.get("box")
+        if box is not None:
+            box = box.detach()
 
-            ext_coord_np, ext_atype_np, mapping_np = extend_coord_with_ghosts(
-                coord_norm, atype_np, box_np, model.get_rcut()
-            )
-            nlist_np = build_neighbor_list(
-                ext_coord_np,
-                ext_atype_np,
-                nloc,
-                model.get_rcut(),
-                model.get_sel(),
-                distinguish_types=False,
-            )
-            ext_coord_np = ext_coord_np.reshape(nframes, -1, 3)
-            nall = ext_coord_np.shape[1]
-            if nall > max_nall:
-                max_nall = nall
-                best_sample = (
-                    ext_coord_np,
-                    ext_atype_np,
-                    mapping_np,
-                    nlist_np,
-                    nloc,
-                    inp,
-                )
+        nframes, nloc = atype.shape[:2]
+        coord_3d = coord.reshape(nframes, nloc, 3)
+        box_flat = box.reshape(nframes, 9) if box is not None else None
 
-        # Add 20 % margin and round up to a multiple of 8.
-        max_nall = ((int(max_nall * 1.2) + 7) // 8) * 8
-        log.info(
-            "Estimated max_nall=%d for compiled model (sampled %d batches).",
-            max_nall,
-            n_sample,
-        )
+        if box_flat is not None:
+            coord_norm = normalize_coord(coord_3d, box_flat.reshape(nframes, 3, 3))
+        else:
+            coord_norm = coord_3d
 
-        # --- Pad the largest sample to max_nall and trace ---
-        assert best_sample is not None
-        ext_coord_np, ext_atype_np, mapping_np, nlist_np, nloc, sample_input = (
-            best_sample
+        ext_coord, ext_atype, mapping = extend_coord_with_ghosts(
+            coord_norm, atype, box_flat, model.get_rcut()
         )
-        nframes = ext_coord_np.shape[0]
-        actual_nall = ext_coord_np.shape[1]
-        pad_n = max_nall - actual_nall
-
-        if pad_n > 0:
-            ext_coord_np = np.pad(ext_coord_np, ((0, 0), (0, pad_n), (0, 0)))
-            ext_atype_np = np.pad(ext_atype_np, ((0, 0), (0, pad_n)))
-            mapping_np = np.pad(mapping_np, ((0, 0), (0, pad_n)))
-
-        ext_coord = torch.tensor(
-            ext_coord_np, dtype=GLOBAL_PT_FLOAT_PRECISION, device=DEVICE
+        nlist_t = build_neighbor_list(
+            ext_coord,
+            ext_atype,
+            nloc,
+            model.get_rcut(),
+            model.get_sel(),
+            distinguish_types=False,
        )
-        ext_atype = torch.tensor(ext_atype_np, dtype=torch.int64, device=DEVICE)
-        nlist_t = torch.tensor(nlist_np, dtype=torch.int64, device=DEVICE)
-        mapping_t = torch.tensor(mapping_np, dtype=torch.int64, device=DEVICE)
-        fparam = sample_input.get("fparam")
-        aparam = sample_input.get("aparam")
+        ext_coord = ext_coord.reshape(nframes, -1, 3)
 
-        compile_opts.pop("dynamic", None)  # always False for padded approach
+        fparam = inp.get("fparam")
+        aparam = inp.get("aparam")
 
         compiled_lower = _trace_and_compile(
             model,
             ext_coord,
             ext_atype,
             nlist_t,
-            mapping_t,
+            mapping,
             fparam,
             aparam,
             compile_opts,
         )
 
-        self.wrapper.model = _CompiledModel(
-            model, compiled_lower, max_nall, compile_opts
-        )
+        self.wrapper.model = _CompiledModel(model, compiled_lower)
         log.info(
-            "Model compiled with padded nall=%d (tracing_mode=real, dynamic=False).",
-            max_nall,
+            "Model compiled (tracing_mode=symbolic, dynamic=True, backend=inductor).",
         )
 
         # ------------------------------------------------------------------
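For reference, the option handling introduced in _trace_and_compile reduces to
the normalization below; normalize_compile_opts is a hypothetical standalone
helper written for illustration, not a function from the diff.

from typing import Any

def normalize_compile_opts(compile_opts: dict[str, Any]) -> dict[str, Any]:
    # caller-supplied values are discarded: backend and dynamic are forced
    # to "inductor" and True at the torch.compile call site
    compile_opts.pop("dynamic", None)
    compile_opts.pop("backend", None)
    # enable inductor's shape padding unless the caller set it explicitly
    compile_opts.setdefault("options", {}).setdefault("shape_padding", True)
    return compile_opts

opts = normalize_compile_opts({"backend": "aot_eager", "fullgraph": True})
# opts == {"fullgraph": True, "options": {"shape_padding": True}}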
