Skip to content

Commit 64dc703

Browse files
author
Han Wang
committed
fix(pt_expt): move optimize_ddp into _compile_model, resolve test symlinks
Move the `torch._dynamo.config.optimize_ddp = False` assignment from module level into `_compile_model()` so that it is applied only when model compilation is actually enabled. Resolve symlinks in the data paths used by `test_fitting_stat.py` so that the test data can be accessed reliably on CI.
1 parent 975db17 commit 64dc703

2 files changed

Lines changed: 17 additions & 12 deletions

File tree

deepmd/pt_expt/train/training.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,6 @@
2424
import torch
2525
import torch.distributed as dist
2626

27-
# Disable DDPOptimizer: our compile region wraps only the inner compute
28-
# function, not the whole DDP model. DDPOptimizer assumes it owns the
29-
# full model graph and splits at bucket boundaries, producing subgraphs
30-
# whose outputs include symbolic integers. AOT Autograd then crashes
31-
# with ``'int' object has no attribute 'meta'``
32-
# (pytorch/pytorch#134182).
33-
torch._dynamo.config.optimize_ddp = False
34-
3527
from deepmd.dpmodel.common import (
3628
to_numpy_array,
3729
)
@@ -903,6 +895,14 @@ def _compile_model(self, compile_opts: dict[str, Any]) -> None:
903895
needed. The coord extension + nlist build (data-dependent
904896
control flow) are kept outside the compiled region.
905897
"""
898+
# Disable DDPOptimizer: our compile region wraps only the inner
899+
# compute function, not the whole DDP model. DDPOptimizer assumes
900+
# it owns the full model graph and splits at bucket boundaries,
901+
# producing subgraphs whose outputs include symbolic integers.
902+
# AOT Autograd then crashes with ``'int' object has no attribute
903+
# 'meta'`` (pytorch/pytorch#134182).
904+
torch._dynamo.config.optimize_ddp = False
905+
906906
from deepmd.dpmodel.utils.nlist import (
907907
build_neighbor_list,
908908
extend_coord_with_ghosts,

source/tests/pt_expt/fitting/test_fitting_stat.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,18 @@ def _get_weighted_fitting_stat(
117117
return weighted_avg, weighted_std
118118

119119

120-
# Paths to the water data used by PT tests
121-
_PT_DATA = str(Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_0")
120+
# Paths to the water data used by PT tests.
121+
# resolve() follows the ``pt/water -> model/water`` symlink so numpy can
122+
# always open the real file, even on CI runners where symlink handling
123+
# can be fragile.
124+
_PT_DATA = str(
125+
(Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_0").resolve()
126+
)
122127
_PT_DATA_NO_FPARAM = str(
123-
Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_1"
128+
(Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "data_1").resolve()
124129
)
125130
_PT_DATA_SINGLE = str(
126-
Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "single"
131+
(Path(__file__).parent.parent.parent / "pt" / "water" / "data" / "single").resolve()
127132
)
128133

129134
_descriptor_se_e2_a = {

0 commit comments

Comments
 (0)