Skip to content

Commit aabb710

Browse files
author
Han Wang
committed
feat(pt_expt): use inductor+dynamic for training compile
Replace aot_eager+padding+manual recompile with symbolic make_fx + torch.compile(backend="inductor", dynamic=True). The compiled graph natively handles varying nframes/nloc/nall, so the per-batch padding and the runtime _recompile pass can be removed. Use a trace-time nframes of 7 (a prime) and reshape with -1 in dpmodel (general_fitting, env_mat) to prevent PyTorch's symbolic tracer from unifying the batch dim with numb_fparam / numb_aparam / ntypes / dim_case_embd. Add TestCompiledVaryingNframesWithParams, covering runtime nframes values that collide with numb_fparam / numb_aparam, and TestCompileCaseEmbdVaryingNframes, covering dim_case_embd > 0 with runtime nframes equal to the embedding dim.
1 parent 665b85a commit aabb710

5 files changed

Lines changed: 541 additions & 235 deletions

File tree

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -674,15 +674,20 @@ def _call_common(
674674
if self.numb_fparam > 0:
675675
assert fparam is not None, "fparam should not be None"
676676
try:
677-
fparam = xp.reshape(fparam, (nf, self.numb_fparam))
677+
# Use -1 for nframes so the shape is inferred from the total
678+
# size. Passing the concrete symbol `nf` here would let
679+
# torch.fx's symbolic tracer specialise when `nf` happens to
680+
# equal another tensor dim (e.g. numb_fparam), baking the
681+
# batch size into the compiled graph.
682+
fparam = xp.reshape(fparam, (-1, self.numb_fparam))
678683
except (ValueError, RuntimeError) as e:
679684
raise ValueError(
680685
f"input fparam: cannot reshape {fparam.shape} "
681686
f"into ({nf}, {self.numb_fparam})."
682687
) from e
683688
fparam = (fparam - self.fparam_avg[...]) * self.fparam_inv_std[...]
684689
fparam = xp.tile(
685-
xp.reshape(fparam, (nf, 1, self.numb_fparam)), (1, nloc, 1)
690+
xp.reshape(fparam, (-1, 1, self.numb_fparam)), (1, nloc, 1)
686691
)
687692
xx = xp.concat(
688693
[xx, fparam],
@@ -697,7 +702,9 @@ def _call_common(
697702
if self.numb_aparam > 0 and not self.use_aparam_as_mask:
698703
assert aparam is not None, "aparam should not be None"
699704
try:
700-
aparam = xp.reshape(aparam, (nf, nloc, self.numb_aparam))
705+
# Use -1 for nframes so the shape is inferred from the total
706+
# size; see the fparam branch above for rationale.
707+
aparam = xp.reshape(aparam, (-1, nloc, self.numb_aparam))
701708
except (ValueError, RuntimeError) as e:
702709
raise ValueError(
703710
f"input aparam: cannot reshape {aparam.shape} "
@@ -744,8 +751,12 @@ def _call_common(
744751
device=array_api_compat.device(descriptor),
745752
)
746753
for type_i in range(self.ntypes):
754+
# Use -1 for nframes so the shape is inferred; see the fparam
755+
# branch above for rationale (avoid symbolic-dim collision
756+
# with numb_fparam / other dims during symbolic tracing).
747757
mask = xp.tile(
748-
xp.reshape((atype == type_i), (nf, nloc, 1)), (1, 1, net_dim_out)
758+
xp.reshape((atype == type_i), (-1, nloc, 1)),
759+
(1, 1, net_dim_out),
749760
)
750761
atom_property = self.nets[(type_i,)](xx)
751762
if self.remove_vaccum_contribution is not None and not (
@@ -761,7 +772,7 @@ def _call_common(
761772
if self.eval_return_middle_output and len(self.neuron) > 0:
762773
mid = self.nets[(type_i,)].call_until_last(xx)
763774
mid_mask = xp.tile(
764-
xp.reshape((atype == type_i), (nf, nloc, 1)),
775+
xp.reshape((atype == type_i), (-1, nloc, 1)),
765776
(1, 1, self.neuron[-1]),
766777
)
767778
mid = xp.where(mid_mask, mid, xp.zeros_like(mid))
@@ -778,7 +789,7 @@ def _call_common(
778789
xp.reshape(atype, (-1,)),
779790
axis=0,
780791
),
781-
(nf, nloc, net_dim_out),
792+
(-1, nloc, net_dim_out),
782793
)
783794
# nf x nloc
784795
exclude_mask = self.emask.build_type_exclude_mask(atype)

deepmd/dpmodel/utils/env_mat.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -68,16 +68,22 @@ def _make_env_mat(
6868
xp = array_api_compat.array_namespace(nlist)
6969
nf, nloc, nnei = nlist.shape
7070
# nf x nall x 3
71-
coord = xp.reshape(coord, (nf, -1, 3))
71+
# Callers may pass either (nf, nall*3) or (nf, nall, 3); normalise
72+
# both to (nf, nall, 3) using -1 for nframes so the shape is inferred
73+
# from the total size. Passing the symbolic nf here can trigger
74+
# torch.fx symbolic-tracer specialisation when nf happens to collide
75+
# with another dim (e.g. numb_fparam) during training compile.
76+
if coord.ndim == 2:
77+
coord = xp.reshape(coord, (-1, coord.shape[1] // 3, 3))
7278
mask = nlist >= 0
7379
nlist = nlist * xp.astype(mask, nlist.dtype)
7480
# nf x (nloc x nnei) x 3
75-
index = xp.tile(xp.reshape(nlist, (nf, -1, 1)), (1, 1, 3))
81+
index = xp.tile(xp.reshape(nlist, (-1, nloc * nnei, 1)), (1, 1, 3))
7682
coord_r = xp_take_along_axis(coord, index, 1)
7783
# nf x nloc x nnei x 3
78-
coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3))
84+
coord_r = xp.reshape(coord_r, (-1, nloc, nnei, 3))
7985
# nf x nloc x 1 x 3
80-
coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (nf, -1, 1, 3))
86+
coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (-1, nloc, 1, 3))
8187
# nf x nloc x nnei x 3
8288
diff = coord_r - coord_l
8389
# nf x nloc x nnei

0 commit comments

Comments
 (0)