Skip to content

Commit 460ccf7

Browse files
author
Han Wang
committed
fix: move _load_custom_ops after deepmd.pt import in gen scripts
The _load_custom_ops() call was placed before deepmd.pt was imported, so the guard `hasattr(torch.ops.deepmd, "border_op")` didn't see the op registered by the standard install path. This caused the build directory's .so to be loaded, and then when deepmd.pt was imported later for .pth export, it loaded the same op again — crash. Fix: move _load_custom_ops() to after the deepmd.pt import so the guard correctly detects the already-registered op and skips.
1 parent 59a3bf7 commit 460ccf7

5 files changed

Lines changed: 13 additions & 9 deletions

File tree

source/tests/infer/gen_dpa1.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,6 @@ def main():
100100
get_model,
101101
)
102102

103-
# Load custom ops after deepmd import to avoid double registration
104-
_load_custom_ops()
105103
_ensure_inductor_compiler()
106104

107105
# ---- 1. DPA1 model config with type_one_side=True ----
@@ -152,6 +150,9 @@ def main():
152150
deserialize_to_file as pt_expt_deserialize_to_file,
153151
)
154152

153+
# Load custom ops after deepmd.pt import to avoid double registration
154+
_load_custom_ops()
155+
155156
base_dir = os.path.dirname(__file__)
156157

157158
pt2_path = os.path.join(base_dir, "deeppot_dpa1.pt2")

source/tests/infer/gen_dpa2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,6 @@ def main():
100100
get_model,
101101
)
102102

103-
# Load custom ops after deepmd import to avoid double registration
104-
_load_custom_ops()
105103
_ensure_inductor_compiler()
106104

107105
# ---- 1. DPA2 model config with type_one_side=True, use_three_body=True ----
@@ -175,6 +173,9 @@ def main():
175173
deserialize_to_file as pt_expt_deserialize_to_file,
176174
)
177175

176+
# Load custom ops after deepmd.pt import to avoid double registration
177+
_load_custom_ops()
178+
178179
base_dir = os.path.dirname(__file__)
179180

180181
pt2_path = os.path.join(base_dir, "deeppot_dpa2.pt2")

source/tests/infer/gen_dpa3.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,6 @@ def main():
105105
get_model,
106106
)
107107

108-
# Load custom ops after deepmd import to avoid double registration
109-
_load_custom_ops()
110108
_ensure_inductor_compiler()
111109

112110
# ---- 1. DPA3 model config (small, fast to compile) ----
@@ -158,6 +156,9 @@ def main():
158156
deserialize_to_file as pt_expt_deserialize_to_file,
159157
)
160158

159+
# Load custom ops after deepmd.pt import to avoid double registration
160+
_load_custom_ops()
161+
161162
base_dir = os.path.dirname(__file__)
162163

163164
pt2_path = os.path.join(base_dir, "deeppot_dpa3.pt2")

source/tests/infer/gen_fparam_aparam.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,6 @@ def main():
7575
get_model,
7676
)
7777

78-
# Load custom ops after deepmd import to avoid double registration
79-
_load_custom_ops()
8078
_ensure_inductor_compiler()
8179

8280
# ---- 1. Model config (type_one_side=True) ----
@@ -134,6 +132,9 @@ def main():
134132
deserialize_to_file as pt_expt_deserialize_to_file,
135133
)
136134

135+
# Load custom ops after deepmd.pt import to avoid double registration
136+
_load_custom_ops()
137+
137138
base_dir = os.path.dirname(__file__)
138139

139140
pt2_path = os.path.join(base_dir, "fparam_aparam.pt2")

source/tests/infer/gen_sea.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def main():
7878
deserialize_to_file,
7979
)
8080

81-
# Load custom ops after deepmd import to avoid double registration
81+
# Load custom ops after deepmd.pt import to avoid double registration
8282
_load_custom_ops()
8383
_ensure_inductor_compiler()
8484

0 commit comments

Comments
 (0)