Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,7 @@ def _eval_model(

results = []
for odef in request_defs:
# it seems not doing conversion
# dp_name = self._OUTDEF_DP2BACKEND[odef.name]
dp_name = odef.name
dp_name = self._OUTDEF_DP2BACKEND[odef.name]
if dp_name in batch_output:
shape = self._get_output_shape(odef, nframes, natoms)
if batch_output[dp_name] is not None:
Expand Down
83 changes: 83 additions & 0 deletions deepmd/dpmodel/model/dipole_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
Any,
)

from deepmd.dpmodel.array_api import (
Array,
)
from deepmd.dpmodel.atomic_model import (
DPDipoleAtomicModel,
)
Expand Down Expand Up @@ -31,3 +34,83 @@
) -> None:
DPModelCommon.__init__(self)
DPDipoleModel_.__init__(self, *args, **kwargs)

def call(
self,
coord: Array,
atype: Array,
box: Array | None = None,
fparam: Array | None = None,
aparam: Array | None = None,
do_atomic_virial: bool = False,
) -> dict[str, Array]:
model_ret = self.call_common(
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
model_predict = {}
model_predict["dipole"] = model_ret["dipole"]
model_predict["global_dipole"] = model_ret["dipole_redu"]
if self.do_grad_r("dipole") and model_ret["dipole_derv_r"] is not None:
model_predict["force"] = model_ret["dipole_derv_r"]
if self.do_grad_c("dipole") and model_ret["dipole_derv_c_redu"] is not None:
model_predict["virial"] = model_ret["dipole_derv_c_redu"]
if do_atomic_virial and model_ret["dipole_derv_c"] is not None:
model_predict["atom_virial"] = model_ret["dipole_derv_c"]
if "mask" in model_ret:
Comment on lines +56 to +64

Check warning

Code scanning / CodeQL

Signature mismatch in overriding method Warning

This method does not accept arbitrary keyword arguments, which overridden
NativeOP.call
does.
This call
correctly calls the base method, but does not match the signature of the overriding method.
This method requires at most 7 positional arguments, whereas overridden
NativeOP.call
may be called with arbitrarily many.
This call
correctly calls the base method, but does not match the signature of the overriding method.
This method requires at least 3 positional arguments, whereas overridden
NativeOP.call
may be called with 1.
This call
correctly calls the base method, but does not match the signature of the overriding method.
model_predict["mask"] = model_ret["mask"]
return model_predict

def call_lower(
self,
extended_coord: Array,
extended_atype: Array,
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
aparam: Array | None = None,
do_atomic_virial: bool = False,
) -> dict[str, Array]:
model_ret = self.call_common_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
model_predict = {}
model_predict["dipole"] = model_ret["dipole"]
model_predict["global_dipole"] = model_ret["dipole_redu"]
if self.do_grad_r("dipole") and model_ret.get("dipole_derv_r") is not None:
model_predict["extended_force"] = model_ret["dipole_derv_r"]
if self.do_grad_c("dipole") and model_ret.get("dipole_derv_c_redu") is not None:
model_predict["virial"] = model_ret["dipole_derv_c_redu"]
if do_atomic_virial and model_ret.get("dipole_derv_c") is not None:
model_predict["extended_virial"] = model_ret["dipole_derv_c"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
return model_predict
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def translated_output_def(self) -> dict[str, Any]:
out_def_data = self.model_output_def().get_data()
output_def = {
"dipole": out_def_data["dipole"],
"global_dipole": out_def_data["dipole_redu"],
}
if self.do_grad_r("dipole"):
output_def["force"] = out_def_data["dipole_derv_r"]
output_def["force"].squeeze(-2)
if self.do_grad_c("dipole"):
output_def["virial"] = out_def_data["dipole_derv_c_redu"]
output_def["virial"].squeeze(-2)
output_def["atom_virial"] = out_def_data["dipole_derv_c"]
output_def["atom_virial"].squeeze(-2)
if "mask" in out_def_data:
output_def["mask"] = out_def_data["mask"]
return output_def
63 changes: 63 additions & 0 deletions deepmd/dpmodel/model/dos_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
Any,
)

from deepmd.dpmodel.array_api import (
Array,
)
from deepmd.dpmodel.atomic_model import (
DPDOSAtomicModel,
)
Expand Down Expand Up @@ -31,3 +34,63 @@
) -> None:
DPModelCommon.__init__(self)
DPDOSModel_.__init__(self, *args, **kwargs)

def call(
self,
coord: Array,
atype: Array,
box: Array | None = None,
fparam: Array | None = None,
aparam: Array | None = None,
do_atomic_virial: bool = False,
) -> dict[str, Array]:
model_ret = self.call_common(
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
model_predict = {}
model_predict["atom_dos"] = model_ret["dos"]
Comment on lines +48 to +56

Check warning

Code scanning / CodeQL

Signature mismatch in overriding method Warning

This method does not accept arbitrary keyword arguments, which overridden
NativeOP.call
does.
This call
correctly calls the base method, but does not match the signature of the overriding method.
This method requires at most 7 positional arguments, whereas overridden
NativeOP.call
may be called with arbitrarily many.
This call
correctly calls the base method, but does not match the signature of the overriding method.
This method requires at least 3 positional arguments, whereas overridden
NativeOP.call
may be called with 1.
This call
correctly calls the base method, but does not match the signature of the overriding method.
model_predict["dos"] = model_ret["dos_redu"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
return model_predict

def call_lower(
self,
extended_coord: Array,
extended_atype: Array,
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
aparam: Array | None = None,
do_atomic_virial: bool = False,
) -> dict[str, Array]:
model_ret = self.call_common_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
model_predict = {}
model_predict["atom_dos"] = model_ret["dos"]
model_predict["dos"] = model_ret["dos_redu"]
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
return model_predict

def translated_output_def(self) -> dict[str, Any]:
out_def_data = self.model_output_def().get_data()
output_def = {
"atom_dos": out_def_data["dos"],
"dos": out_def_data["dos_redu"],
}
if "mask" in out_def_data:
output_def["mask"] = out_def_data["mask"]
return output_def
85 changes: 85 additions & 0 deletions deepmd/dpmodel/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
Any,
)

from deepmd.dpmodel.array_api import (
Array,
)
from deepmd.dpmodel.atomic_model.linear_atomic_model import (
DPZBLLinearEnergyAtomicModel,
)
Expand Down Expand Up @@ -34,6 +37,88 @@
) -> None:
super().__init__(*args, **kwargs)

def call(
self,
coord: Array,
atype: Array,
box: Array | None = None,
fparam: Array | None = None,
aparam: Array | None = None,
do_atomic_virial: bool = False,
) -> dict[str, Array]:
model_ret = self.call_common(
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad_r("energy") and model_ret["energy_derv_r"] is not None:
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
if self.do_grad_c("energy") and model_ret["energy_derv_c_redu"] is not None:
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial and model_ret["energy_derv_c"] is not None:
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2)
if "mask" in model_ret:
Comment on lines +58 to +66

Check warning

Code scanning / CodeQL

Signature mismatch in overriding method Warning

This method does not accept arbitrary keyword arguments, which overridden
NativeOP.call
does.
This call
correctly calls the base method, but does not match the signature of the overriding method.
This method requires at most 7 positional arguments, whereas overridden
NativeOP.call
may be called with arbitrarily many.
This call
correctly calls the base method, but does not match the signature of the overriding method.
This method requires at least 3 positional arguments, whereas overridden
NativeOP.call
may be called with 1.
This call
correctly calls the base method, but does not match the signature of the overriding method.
model_predict["mask"] = model_ret["mask"]
return model_predict

def call_lower(
self,
extended_coord: Array,
extended_atype: Array,
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
aparam: Array | None = None,
do_atomic_virial: bool = False,
) -> dict[str, Array]:
model_ret = self.call_common_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad_r("energy") and model_ret.get("energy_derv_r") is not None:
model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
if self.do_grad_c("energy") and model_ret.get("energy_derv_c_redu") is not None:
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial and model_ret.get("energy_derv_c") is not None:
model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(
-2
)
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
return model_predict

def translated_output_def(self) -> dict[str, Any]:
out_def_data = self.model_output_def().get_data()
output_def = {
"atom_energy": out_def_data["energy"],
"energy": out_def_data["energy_redu"],
}
if self.do_grad_r("energy"):
output_def["force"] = out_def_data["energy_derv_r"]
output_def["force"].squeeze(-2)
if self.do_grad_c("energy"):
output_def["virial"] = out_def_data["energy_derv_c_redu"]
output_def["virial"].squeeze(-2)
output_def["atom_virial"] = out_def_data["energy_derv_c"]
output_def["atom_virial"].squeeze(-2)
if "mask" in out_def_data:
output_def["mask"] = out_def_data["mask"]
return output_def

@classmethod
def update_sel(
cls,
Expand Down
69 changes: 69 additions & 0 deletions deepmd/dpmodel/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
Any,
)

from deepmd.dpmodel.array_api import (
Array,
)
from deepmd.dpmodel.atomic_model import (
DPEnergyAtomicModel,
)
Expand Down Expand Up @@ -48,6 +51,72 @@
return self.hess_fitting_def
return super().atomic_output_def()

def call(
self,
coord: Array,
atype: Array,
box: Array | None = None,
fparam: Array | None = None,
aparam: Array | None = None,
do_atomic_virial: bool = False,
) -> dict[str, Array]:
model_ret = self.call_common(
coord,
atype,
box,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad_r("energy") and model_ret["energy_derv_r"] is not None:
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
if self.do_grad_c("energy") and model_ret["energy_derv_c_redu"] is not None:
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial and model_ret["energy_derv_c"] is not None:
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2)
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
if self._enable_hessian and model_ret.get("energy_derv_r_derv_r") is not None:
model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-3)
return model_predict

def call_lower(
self,
extended_coord: Array,
extended_atype: Array,
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
aparam: Array | None = None,
do_atomic_virial: bool = False,
) -> dict[str, Array]:
model_ret = self.call_common_lower(
extended_coord,
extended_atype,
nlist,
mapping,
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
)
model_predict = {}
model_predict["atom_energy"] = model_ret["energy"]
model_predict["energy"] = model_ret["energy_redu"]
if self.do_grad_r("energy") and model_ret.get("energy_derv_r") is not None:
model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2)
Comment on lines +101 to +109

Check warning

Code scanning / CodeQL

Signature mismatch in overriding method Warning

This method does not accept arbitrary keyword arguments, which overridden
NativeOP.call
does.
This call
correctly calls the base method, but does not match the signature of the overriding method.
This method requires at most 7 positional arguments, whereas overridden
NativeOP.call
may be called with arbitrarily many.
This call
correctly calls the base method, but does not match the signature of the overriding method.
This method requires at least 3 positional arguments, whereas overridden
NativeOP.call
may be called with 1.
This call
correctly calls the base method, but does not match the signature of the overriding method.
if self.do_grad_c("energy") and model_ret.get("energy_derv_c_redu") is not None:
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial and model_ret.get("energy_derv_c") is not None:
model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(
-2
)
if "mask" in model_ret:
model_predict["mask"] = model_ret["mask"]
return model_predict

def translated_output_def(self) -> dict[str, Any]:
"""Get the translated output definition.

Expand Down
11 changes: 5 additions & 6 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def enable_compression(
check_frequency,
)

def call(
def call_common(
self,
coord: Array,
atype: Array,
Expand Down Expand Up @@ -262,7 +262,7 @@ def call(
)
del coord, box, fparam, aparam
model_predict = model_call_from_call_lower(
call_lower=self.call_lower,
call_lower=self.call_common_lower,
rcut=self.get_rcut(),
sel=self.get_sel(),
mixed_types=self.mixed_types(),
Expand All @@ -277,7 +277,7 @@ def call(
model_predict = self._output_type_cast(model_predict, input_prec)
return model_predict

def call_lower(
def call_common_lower(
self,
extended_coord: Array,
extended_atype: Array,
Expand Down Expand Up @@ -365,9 +365,8 @@ def forward_common_atomic(
mask=atomic_ret["mask"] if "mask" in atomic_ret else None,
)

forward_lower = call_lower
forward_common = call
forward_common_lower = call_lower
call = call_common
call_lower = call_common_lower

def get_out_bias(self) -> Array:
"""Get the output bias."""
Expand Down
Loading
Loading