Skip to content

Commit 1904746

Browse files
Merge branch 'devel' into refine_pd_UT
2 parents 8680a43 + 188dae3 commit 1904746

16 files changed

Lines changed: 123 additions & 23 deletions

File tree

.github/workflows/test_cuda.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
runs-on: nvidia
2020
# https://github.com/deepmodeling/deepmd-kit/pull/2884#issuecomment-1744216845
2121
container:
22-
image: nvidia/cuda:12.6.2-cudnn-devel-ubuntu22.04
22+
image: nvidia/cuda:12.9.1-cudnn-devel-ubuntu22.04
2323
options: --gpus all
2424
if: github.repository_owner == 'deepmodeling' && (github.event_name == 'pull_request' && github.event.label && github.event.label.name == 'Test CUDA' || github.event_name == 'workflow_dispatch' || github.event_name == 'merge_group')
2525
steps:
@@ -49,6 +49,8 @@ jobs:
4949
export TENSORFLOW_ROOT=$(python -c 'import importlib.util,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)')
5050
pip install --find-links "https://www.paddlepaddle.org.cn/packages/nightly/cu126/paddlepaddle-gpu/" --index-url https://pypi.org/simple "paddlepaddle-gpu==3.3.0.dev20251204"
5151
source/install/uv_with_retry.sh pip install --system -v -e .[gpu,test,lmp,cu12,torch,jax] mpi4py --reinstall-package deepmd-kit
52+
# See https://github.com/jax-ml/jax/issues/29042
53+
source/install/uv_with_retry.sh pip install --system -U 'nvidia-cublas-cu12>=12.9.0.13'
5254
env:
5355
DP_VARIANT: cuda
5456
DP_ENABLE_NATIVE_OPTIMIZATION: 1

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def __init__(
5151
self.type_map = type_map
5252
self.descriptor = descriptor
5353
self.fitting = fitting
54+
if hasattr(self.fitting, "reinit_exclude"):
55+
self.fitting.reinit_exclude(self.atom_exclude_types)
5456
self.type_map = type_map
5557
super().init_out_stat()
5658

@@ -191,7 +193,7 @@ def change_type_map(
191193
if model_with_new_type_stat is not None
192194
else None,
193195
)
194-
self.fitting_net.change_type_map(type_map=type_map)
196+
self.fitting.change_type_map(type_map=type_map)
195197

196198
def serialize(self) -> dict:
197199
dd = super().serialize()

deepmd/dpmodel/infer/deep_eval.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def __init__(
9090

9191
model_data = load_dp_model(model_file)
9292
self.dp = BaseModel.deserialize(model_data["model"])
93+
self.dp.model_def_script = json.dumps(model_data.get("model_def_script", {}))
9394
self.rcut = self.dp.get_rcut()
9495
self.type_map = self.dp.get_type_map()
9596
if isinstance(auto_batch_size, bool):

deepmd/jax/fitting/fitting.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any:
4343
"fparam_inv_std",
4444
"aparam_avg",
4545
"aparam_inv_std",
46+
"case_embd",
4647
"default_fparam_tensor",
4748
}:
4849
value = to_jax_array(value)

deepmd/jax/infer/deep_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def __init__(
104104
stablehlo_atomic_virial_no_ghost=model_data["@variables"][
105105
"stablehlo_atomic_virial_no_ghost"
106106
].tobytes(),
107-
model_def_script=model_data["model_def_script"],
107+
model_def_script=json.dumps(model_data["model_def_script"]),
108108
**model_data["constants"],
109109
)
110110
elif model_file.endswith(".savedmodel"):

deepmd/pd/infer/deep_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,7 @@ def eval_typeebd(self) -> np.ndarray:
726726
typeebd = paddle.concat(out, axis=1)
727727
return to_numpy_array(typeebd)
728728

729-
def get_model_def_script(self) -> str:
729+
def get_model_def_script(self) -> dict:
730730
"""Get model definition script."""
731731
return self.model_def_script
732732

deepmd/pt/infer/deep_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -684,7 +684,7 @@ def eval_typeebd(self) -> np.ndarray:
684684
typeebd = torch.cat(out, dim=1)
685685
return to_numpy_array(typeebd)
686686

687-
def get_model_def_script(self) -> str:
687+
def get_model_def_script(self) -> dict:
688688
"""Get model definition script."""
689689
return self.model_def_script
690690

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def __init__(
6464
self.rcut = self.descriptor.get_rcut()
6565
self.sel = self.descriptor.get_sel()
6666
self.fitting_net = fitting
67+
if hasattr(self.fitting_net, "reinit_exclude"):
68+
self.fitting_net.reinit_exclude(self.atom_exclude_types)
6769
super().init_out_stat()
6870
self.enable_eval_descriptor_hook = False
6971
self.enable_eval_fitting_last_layer_hook = False
@@ -151,6 +153,9 @@ def change_type_map(
151153
else None,
152154
)
153155
self.fitting_net.change_type_map(type_map=type_map)
156+
# Reinitialize fitting to get correct sel_type
157+
if hasattr(self.fitting_net, "reinit_exclude"):
158+
self.fitting_net.reinit_exclude(self.atom_exclude_types)
154159

155160
def has_message_passing(self) -> bool:
156161
"""Returns whether the atomic model has message passing."""

deepmd/tf/model/model.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,16 +1003,23 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
10031003
check_version_compatibility(data.pop("@version", 2), 2, 1)
10041004
descriptor = Descriptor.deserialize(data.pop("descriptor"), suffix=suffix)
10051005
# bias_atom_e and out_bias are now completely independent - no conversion needed
1006-
fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix)
1006+
fitting_dict = data.pop("fitting", {})
1007+
atom_exclude_types = data.pop("atom_exclude_types", [])
1008+
if len(atom_exclude_types) > 0:
1009+
# get sel_type from complement of atom_exclude_types
1010+
full_type_list = np.arange(len(data["type_map"]), dtype=int)
1011+
sel_type = np.setdiff1d(
1012+
full_type_list, atom_exclude_types, assume_unique=True
1013+
)
1014+
fitting_dict["sel_type"] = sel_type.tolist()
1015+
fitting = Fitting.deserialize(fitting_dict, suffix=suffix)
10071016
# pass descriptor type embedding to model
10081017
if descriptor.explicit_ntypes:
10091018
type_embedding = descriptor.type_embedding
10101019
fitting.dim_descrpt -= type_embedding.neuron[-1]
10111020
else:
10121021
type_embedding = None
10131022
# BEGIN not supported keys
1014-
if len(data.pop("atom_exclude_types")) > 0:
1015-
raise NotImplementedError("atom_exclude_types is not supported")
10161023
if len(data.pop("pair_exclude_types")) > 0:
10171024
raise NotImplementedError("pair_exclude_types is not supported")
10181025
data.pop("rcond", None)

source/tests/array_api_strict/fitting/fitting.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any:
3131
"fparam_inv_std",
3232
"aparam_avg",
3333
"aparam_inv_std",
34+
"case_embd",
3435
"default_fparam_tensor",
3536
}:
3637
value = to_array_api_strict_array(value)

0 commit comments

Comments
 (0)