Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,9 @@ def call(
)
# (nframes, nloc, 1)
bias = bias[..., None] * scale_atype
eye = xp.eye(3, dtype=descriptor.dtype)
eye = xp.eye(
3, dtype=descriptor.dtype, device=array_api_compat.device(descriptor)
)
eye = xp.tile(eye, (nframes, nloc, 1, 1))
# (nframes, nloc, 3, 3)
bias = bias[..., None] * eye
Expand Down
16 changes: 16 additions & 0 deletions deepmd/pt_expt/fitting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,31 @@
from .base_fitting import (
BaseFitting,
)
from .dipole_fitting import (
DipoleFitting,
)
from .dos_fitting import (
DOSFittingNet,
)
from .ener_fitting import (
EnergyFittingNet,
)
from .invar_fitting import (
InvarFitting,
)
from .polarizability_fitting import (
PolarFitting,
)
from .property_fitting import (
PropertyFittingNet,
)

__all__ = [
"BaseFitting",
"DOSFittingNet",
"DipoleFitting",
"EnergyFittingNet",
"InvarFitting",
"PolarFitting",
"PropertyFittingNet",
]
23 changes: 23 additions & 0 deletions deepmd/pt_expt/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

from deepmd.dpmodel.fitting.dipole_fitting import DipoleFitting as DipoleFittingDP
from deepmd.pt_expt.common import (
register_dpmodel_mapping,
torch_module,
)

from .base_fitting import (
BaseFitting,
)


@BaseFitting.register("dipole")
@torch_module
class DipoleFitting(DipoleFittingDP):
pass


register_dpmodel_mapping(
DipoleFittingDP,
lambda v: DipoleFitting.deserialize(v.serialize()),
)
23 changes: 23 additions & 0 deletions deepmd/pt_expt/fitting/dos_fitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP
from deepmd.pt_expt.common import (
register_dpmodel_mapping,
torch_module,
)

from .base_fitting import (
BaseFitting,
)


@BaseFitting.register("dos")
@torch_module
class DOSFittingNet(DOSFittingNetDP):
pass


register_dpmodel_mapping(
DOSFittingNetDP,
lambda v: DOSFittingNet.deserialize(v.serialize()),
)
23 changes: 23 additions & 0 deletions deepmd/pt_expt/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

from deepmd.dpmodel.fitting.polarizability_fitting import PolarFitting as PolarFittingDP
from deepmd.pt_expt.common import (
register_dpmodel_mapping,
torch_module,
)

from .base_fitting import (
BaseFitting,
)


@BaseFitting.register("polar")
@torch_module
class PolarFitting(PolarFittingDP):
pass


register_dpmodel_mapping(
PolarFittingDP,
lambda v: PolarFitting.deserialize(v.serialize()),
)
25 changes: 25 additions & 0 deletions deepmd/pt_expt/fitting/property_fitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

from deepmd.dpmodel.fitting.property_fitting import (
PropertyFittingNet as PropertyFittingNetDP,
)
from deepmd.pt_expt.common import (
register_dpmodel_mapping,
torch_module,
)

from .base_fitting import (
BaseFitting,
)


@BaseFitting.register("property")
@torch_module
class PropertyFittingNet(PropertyFittingNetDP):
pass


register_dpmodel_mapping(
PropertyFittingNetDP,
lambda v: PropertyFittingNet.deserialize(v.serialize()),
)
25 changes: 25 additions & 0 deletions source/tests/consistent/fitting/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_PT_EXPT,
INSTALLED_TF,
CommonTest,
parameterized,
Expand All @@ -33,6 +34,13 @@
from deepmd.pt.utils.env import DEVICE as PT_DEVICE
else:
DipoleFittingPT = object
if INSTALLED_PT_EXPT:
from deepmd.pt_expt.fitting.dipole_fitting import (
DipoleFitting as DipoleFittingPTExpt,
)
from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE
else:
DipoleFittingPTExpt = None
if INSTALLED_TF:
from deepmd.tf.fit.dipole import DipoleFittingSeA as DipoleFittingTF
else:
Expand Down Expand Up @@ -116,12 +124,17 @@ def skip_pt(self) -> bool:
tf_class = DipoleFittingTF
dp_class = DipoleFittingDP
pt_class = DipoleFittingPT
pt_expt_class = DipoleFittingPTExpt
jax_class = DipoleFittingJAX
array_api_strict_class = DipoleFittingArrayAPIStrict
args = fitting_dipole()
skip_jax = not INSTALLED_JAX
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

@property
def skip_pt_expt(self) -> bool:
return CommonTest.skip_pt_expt

def setUp(self) -> None:
CommonTest.setUp(self)

Expand Down Expand Up @@ -184,6 +197,18 @@ def eval_pt(self, pt_obj: Any) -> Any:
.numpy()
)

def eval_pt_expt(self, pt_expt_obj: Any) -> Any:
return (
pt_expt_obj(
torch.from_numpy(self.inputs).to(device=PT_EXPT_DEVICE),
torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_EXPT_DEVICE),
gr=torch.from_numpy(self.gr).to(device=PT_EXPT_DEVICE),
)["dipole"]
.detach()
.cpu()
.numpy()
)

def eval_dp(self, dp_obj: Any) -> Any:
(
resnet_dt,
Expand Down
36 changes: 36 additions & 0 deletions source/tests/consistent/fitting/test_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_PT_EXPT,
INSTALLED_TF,
CommonTest,
parameterized,
Expand All @@ -33,6 +34,11 @@
from deepmd.pt.utils.env import DEVICE as PT_DEVICE
else:
DOSFittingPT = object
if INSTALLED_PT_EXPT:
from deepmd.pt_expt.fitting.dos_fitting import DOSFittingNet as DOSFittingPTExpt
from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE
else:
DOSFittingPTExpt = None
if INSTALLED_TF:
from deepmd.tf.fit.dos import DOSFitting as DOSFittingTF
else:
Expand Down Expand Up @@ -106,9 +112,14 @@ def skip_jax(self) -> bool:
def skip_array_api_strict(self) -> bool:
return not INSTALLED_ARRAY_API_STRICT

@property
def skip_pt_expt(self) -> bool:
return CommonTest.skip_pt_expt

tf_class = DOSFittingTF
dp_class = DOSFittingDP
pt_class = DOSFittingPT
pt_expt_class = DOSFittingPTExpt
jax_class = DOSFittingJAX
array_api_strict_class = DOSFittingStrict
args = fitting_dos()
Expand Down Expand Up @@ -187,6 +198,31 @@ def eval_pt(self, pt_obj: Any) -> Any:
.numpy()
)

def eval_pt_expt(self, pt_expt_obj: Any) -> Any:
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
numb_aparam,
numb_dos,
) = self.param
return (
pt_expt_obj(
torch.from_numpy(self.inputs).to(device=PT_EXPT_DEVICE),
torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_EXPT_DEVICE),
fparam=torch.from_numpy(self.fparam).to(device=PT_EXPT_DEVICE)
if numb_fparam
else None,
aparam=torch.from_numpy(self.aparam).to(device=PT_EXPT_DEVICE)
if numb_aparam
else None,
)["dos"]
.detach()
.cpu()
.numpy()
)

def eval_dp(self, dp_obj: Any) -> Any:
(
resnet_dt,
Expand Down
25 changes: 25 additions & 0 deletions source/tests/consistent/fitting/test_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_PT_EXPT,
INSTALLED_TF,
CommonTest,
parameterized,
Expand All @@ -33,6 +34,13 @@
from deepmd.pt.utils.env import DEVICE as PT_DEVICE
else:
PolarFittingPT = object
if INSTALLED_PT_EXPT:
from deepmd.pt_expt.fitting.polarizability_fitting import (
PolarFitting as PolarFittingPTExpt,
)
from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE
else:
PolarFittingPTExpt = None
if INSTALLED_TF:
from deepmd.tf.fit.polar import PolarFittingSeA as PolarFittingTF
else:
Expand Down Expand Up @@ -90,12 +98,17 @@ def skip_pt(self) -> bool:
tf_class = PolarFittingTF
dp_class = PolarFittingDP
pt_class = PolarFittingPT
pt_expt_class = PolarFittingPTExpt
jax_class = PolarFittingJAX
array_api_strict_class = PolarFittingArrayAPIStrict
args = fitting_polar()
skip_jax = not INSTALLED_JAX
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

@property
def skip_pt_expt(self) -> bool:
return CommonTest.skip_pt_expt

def setUp(self) -> None:
CommonTest.setUp(self)

Expand Down Expand Up @@ -155,6 +168,18 @@ def eval_pt(self, pt_obj: Any) -> Any:
.numpy()
)

def eval_pt_expt(self, pt_expt_obj: Any) -> Any:
return (
pt_expt_obj(
torch.from_numpy(self.inputs).to(device=PT_EXPT_DEVICE),
torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_EXPT_DEVICE),
gr=torch.from_numpy(self.gr).to(device=PT_EXPT_DEVICE),
)["polarizability"]
.detach()
.cpu()
.numpy()
)

def eval_dp(self, dp_obj: Any) -> Any:
(
resnet_dt,
Expand Down
39 changes: 39 additions & 0 deletions source/tests/consistent/fitting/test_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
INSTALLED_ARRAY_API_STRICT,
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_PT_EXPT,
CommonTest,
parameterized,
)
Expand All @@ -37,6 +38,13 @@
from deepmd.pt.utils.env import DEVICE as PT_DEVICE
else:
PropertyFittingPT = object
if INSTALLED_PT_EXPT:
from deepmd.pt_expt.fitting.property_fitting import (
PropertyFittingNet as PropertyFittingPTExpt,
)
from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE
else:
PropertyFittingPTExpt = None
if INSTALLED_JAX:
from deepmd.jax.env import (
jnp,
Expand Down Expand Up @@ -110,9 +118,14 @@ def skip_tf(self) -> bool:
skip_jax = not INSTALLED_JAX
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT

@property
def skip_pt_expt(self) -> bool:
return CommonTest.skip_pt_expt

tf_class = PropertyFittingTF
dp_class = PropertyFittingDP
pt_class = PropertyFittingPT
pt_expt_class = PropertyFittingPTExpt
jax_class = PropertyFittingJAX
array_api_strict_class = PropertyFittingStrict
args = fitting_property()
Expand Down Expand Up @@ -194,6 +207,32 @@ def eval_pt(self, pt_obj: Any) -> Any:
.numpy()
)

def eval_pt_expt(self, pt_expt_obj: Any) -> Any:
(
resnet_dt,
precision,
mixed_types,
numb_fparam,
numb_aparam,
task_dim,
intensive,
) = self.param
return (
pt_expt_obj(
torch.from_numpy(self.inputs).to(device=PT_EXPT_DEVICE),
torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_EXPT_DEVICE),
fparam=torch.from_numpy(self.fparam).to(device=PT_EXPT_DEVICE)
if numb_fparam
else None,
aparam=torch.from_numpy(self.aparam).to(device=PT_EXPT_DEVICE)
if numb_aparam
else None,
)[pt_expt_obj.var_name]
.detach()
.cpu()
.numpy()
)

def eval_dp(self, dp_obj: Any) -> Any:
(
resnet_dt,
Expand Down
Loading