Skip to content

Commit 478cba6

Browse files
committed
fix charge_spin & limit torch to 2.11 for compile
1 parent 049fb95 commit 478cba6

5 files changed

Lines changed: 87 additions & 45 deletions

File tree

deepmd/pt/entrypoints/freeze_pt2.py

Lines changed: 60 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -354,34 +354,21 @@ def _make_sample_inputs(
354354
)
355355
charge_spin = None
356356
if dim_chg_spin > 0:
357-
default_chg_spin = model.get_default_chg_spin()
358-
if default_chg_spin is None:
359-
raise ValueError(
360-
"SeZM .pt2 freeze requires default_chg_spin when charge/spin "
361-
"conditioning is enabled; runtime charge_spin input is not exposed."
362-
)
363-
charge_spin = (
364-
default_chg_spin.to(device=device, dtype=torch.float64)
365-
.view(1, dim_chg_spin)
366-
.expand(nframes, -1)
367-
.contiguous()
357+
charge_spin = torch.zeros(
358+
nframes, dim_chg_spin, dtype=torch.float64, device=device
368359
)
369360
if has_spin:
370-
if charge_spin is not None:
371-
return (
372-
ext_coord,
373-
ext_atype,
374-
ext_spin,
375-
nlist_t,
376-
mapping_t,
377-
fparam,
378-
aparam,
379-
charge_spin,
380-
)
381-
return ext_coord, ext_atype, ext_spin, nlist_t, mapping_t, fparam, aparam
382-
if charge_spin is not None:
383-
return ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin
384-
return ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam
361+
return (
362+
ext_coord,
363+
ext_atype,
364+
ext_spin,
365+
nlist_t,
366+
mapping_t,
367+
fparam,
368+
aparam,
369+
charge_spin,
370+
)
371+
return ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin
385372

386373

387374
def _resolve_nframes(
@@ -446,6 +433,9 @@ def _build_dynamic_shapes(
446433
nloc_dim = torch.export.Dim("nloc", min=1)
447434
fparam = sample_inputs[5] if has_spin else sample_inputs[4]
448435
aparam = sample_inputs[6] if has_spin else sample_inputs[5]
436+
charge_spin = None
437+
if has_charge_spin:
438+
charge_spin = sample_inputs[7] if has_spin else sample_inputs[6]
449439
if has_spin:
450440
shapes = (
451441
{0: nframes_dim, 1: nall_dim}, # extended_coord
@@ -457,7 +447,7 @@ def _build_dynamic_shapes(
457447
{0: nframes_dim, 1: nloc_dim} if aparam is not None else None,
458448
)
459449
if has_charge_spin:
460-
shapes = (*shapes, {0: nframes_dim})
450+
shapes = (*shapes, {0: nframes_dim} if charge_spin is not None else None)
461451
return shapes
462452
shapes = (
463453
{0: nframes_dim, 1: nall_dim}, # extended_coord: (nframes, nall, 3)
@@ -468,7 +458,7 @@ def _build_dynamic_shapes(
468458
{0: nframes_dim, 1: nloc_dim} if aparam is not None else None,
469459
)
470460
if has_charge_spin:
471-
shapes = (*shapes, {0: nframes_dim})
461+
shapes = (*shapes, {0: nframes_dim} if charge_spin is not None else None)
472462
return shapes
473463

474464

@@ -527,10 +517,48 @@ def freeze_sezm_to_pt2(
527517
# do_atomic_virial=True pulls every key that DeepPotPTExpt may read
528518
# (energy, energy_redu, energy_derv_r, energy_derv_c, energy_derv_c_redu)
529519
# into the traced graph.
530-
traced = model.forward_common_lower_exportable(
531-
*sample_inputs_cpu,
532-
do_atomic_virial=True,
533-
)
520+
if is_spin:
521+
(
522+
ext_coord,
523+
ext_atype,
524+
ext_spin,
525+
nlist_t,
526+
mapping_t,
527+
fparam,
528+
aparam,
529+
charge_spin,
530+
) = sample_inputs_cpu
531+
traced = model.forward_common_lower_exportable(
532+
ext_coord,
533+
ext_atype,
534+
ext_spin,
535+
nlist_t,
536+
mapping_t,
537+
fparam=fparam,
538+
aparam=aparam,
539+
charge_spin=charge_spin,
540+
do_atomic_virial=True,
541+
)
542+
else:
543+
(
544+
ext_coord,
545+
ext_atype,
546+
nlist_t,
547+
mapping_t,
548+
fparam,
549+
aparam,
550+
charge_spin,
551+
) = sample_inputs_cpu
552+
traced = model.forward_common_lower_exportable(
553+
ext_coord,
554+
ext_atype,
555+
nlist_t,
556+
mapping_t,
557+
fparam=fparam,
558+
aparam=aparam,
559+
charge_spin=charge_spin,
560+
do_atomic_virial=True,
561+
)
534562

535563
# Output key order is taken from a concrete run; Python dict order
536564
# is stable and matches what DeepPotPTExpt::extract_outputs zips

deepmd/pt/model/model/sezm_model.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,9 @@
376376
from einops import (
377377
rearrange,
378378
)
379+
from packaging.version import (
380+
Version,
381+
)
379382
from torch.fx.experimental.proxy_tensor import (
380383
make_fx,
381384
)
@@ -475,6 +478,16 @@ def _parse_optional_env_bool(var_name: str) -> bool | None:
475478
)
476479

477480

481+
def _check_compile_torch_version() -> None:
482+
"""Fail fast when SeZM compile is requested on unsupported PyTorch."""
483+
version = Version(torch.__version__).release
484+
if len(version) < 2 or version[:2] != (2, 11):
485+
raise RuntimeError(
486+
"SeZM `use_compile` and `DP_COMPILE_INFER` require PyTorch 2.11.x; "
487+
f"found torch {torch.__version__}."
488+
)
489+
490+
478491
def _strip_saved_tensor_detach(gm: torch.fx.GraphModule) -> None:
479492
"""Strip ``aten.detach`` nodes that ``make_fx`` inserts for saved tensors.
480493
@@ -614,6 +627,8 @@ def __init__(
614627
self._env_use_compile_infer: bool | None = _parse_optional_env_bool(
615628
"DP_COMPILE_INFER"
616629
)
630+
if self.use_compile or self._env_use_compile_infer is True:
631+
_check_compile_torch_version()
617632

618633
# === Bridging (optional short-range zone bridging) ===
619634
self.bridging_method: str = str(bridging_method).upper()
@@ -1828,8 +1843,9 @@ def forward_common_lower_exportable(
18281843
mapping: torch.Tensor | None = None,
18291844
fparam: torch.Tensor | None = None,
18301845
aparam: torch.Tensor | None = None,
1831-
do_atomic_virial: bool = False,
18321846
charge_spin: torch.Tensor | None = None,
1847+
*,
1848+
do_atomic_virial: bool = False,
18331849
) -> torch.nn.Module:
18341850
"""Trace ``forward_common_lower`` into an exportable FX ``GraphModule``.
18351851
@@ -1884,9 +1900,8 @@ def fn(
18841900
mapping_: torch.Tensor | None,
18851901
fparam_: torch.Tensor | None,
18861902
aparam_: torch.Tensor | None,
1887-
*maybe_charge_spin: torch.Tensor | None,
1903+
charge_spin_: torch.Tensor | None,
18881904
) -> dict[str, torch.Tensor]:
1889-
charge_spin_ = maybe_charge_spin[0] if maybe_charge_spin else None
18901905
return lower_fn(
18911906
ext_coord,
18921907
ext_atype,
@@ -1905,7 +1920,7 @@ def fn(
19051920
dtype=extended_coord.dtype,
19061921
device=extended_coord.device,
19071922
)
1908-
trace_inputs = (*trace_inputs, charge_spin)
1923+
trace_inputs = (*trace_inputs, charge_spin)
19091924

19101925
return self._trace_lower_exportable(
19111926
fn,

deepmd/pt/model/model/sezm_spin_model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,9 @@ def forward_common_lower_exportable(
301301
mapping: torch.Tensor | None = None,
302302
fparam: torch.Tensor | None = None,
303303
aparam: torch.Tensor | None = None,
304-
do_atomic_virial: bool = False,
305304
charge_spin: torch.Tensor | None = None,
305+
*,
306+
do_atomic_virial: bool = False,
306307
) -> torch.nn.Module:
307308
"""Trace the spin lower interface into an exportable FX graph."""
308309
extra_sort = self.need_sorted_nlist_for_lower()
@@ -339,9 +340,8 @@ def fn(
339340
mapping_: torch.Tensor | None,
340341
fparam_: torch.Tensor | None,
341342
aparam_: torch.Tensor | None,
342-
*maybe_charge_spin: torch.Tensor | None,
343+
charge_spin_: torch.Tensor | None,
343344
) -> dict[str, torch.Tensor]:
344-
charge_spin_ = maybe_charge_spin[0] if maybe_charge_spin else None
345345
return lower_fn(
346346
ext_coord,
347347
ext_atype,
@@ -369,7 +369,7 @@ def fn(
369369
dtype=extended_coord.dtype,
370370
device=extended_coord.device,
371371
)
372-
trace_inputs = (*trace_inputs, charge_spin)
372+
trace_inputs = (*trace_inputs, charge_spin)
373373

374374
return self._trace_lower_exportable(
375375
fn,

deepmd/pt_expt/infer/deep_eval.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,9 +1167,8 @@ def _eval_model(
11671167
mapping_t,
11681168
fparam_t,
11691169
aparam_t,
1170+
charge_spin_t,
11701171
)
1171-
if charge_spin_t is not None:
1172-
model_inputs = (*model_inputs, charge_spin_t)
11731172
if self._is_pt2:
11741173
# AOTInductor's __call__ unflattens output using stored out_spec,
11751174
# returning a dict just like the .pte module.
@@ -1320,9 +1319,8 @@ def _eval_model_spin(
13201319
mapping_t,
13211320
fparam_t,
13221321
aparam_t,
1322+
charge_spin_t,
13231323
)
1324-
if charge_spin_t is not None:
1325-
model_inputs = (*model_inputs, charge_spin_t)
13261324
if self._is_pt2:
13271325
model_ret = self._pt2_runner(*model_inputs)
13281326
else:

source/tests/pt/model/test_sezm_export.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def _eager_forward(
174174
sample_inputs: tuple,
175175
) -> dict[str, torch.Tensor]:
176176
"""Mirror the trace closure: fresh leaf coord + ``requires_grad=True``."""
177-
ext_coord, ext_atype, nlist, mapping, fparam, aparam = sample_inputs
177+
ext_coord, ext_atype, nlist, mapping, fparam, aparam, charge_spin = sample_inputs
178178
eager_coord = ext_coord.detach().clone().requires_grad_(True)
179179
return model.forward_common_lower(
180180
eager_coord,
@@ -183,6 +183,7 @@ def _eager_forward(
183183
mapping=mapping,
184184
fparam=fparam,
185185
aparam=aparam,
186+
charge_spin=charge_spin,
186187
do_atomic_virial=True,
187188
extra_nlist_sort=model.need_sorted_nlist_for_lower(),
188189
)

0 commit comments

Comments
 (0)