Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions deepmd/dpmodel/model/ener_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from copy import (
deepcopy,
)

from deepmd.dpmodel.atomic_model import (
DPEnergyAtomicModel,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
)

from .dp_model import (
DPModelCommon,
Expand All @@ -25,3 +32,15 @@ def __init__(
) -> None:
DPModelCommon.__init__(self)
DPEnergyModel_.__init__(self, *args, **kwargs)
self._enable_hessian = False
self.hess_fitting_def = None

def enable_hessian(self):
self.hess_fitting_def = deepcopy(self.atomic_output_def())
self.hess_fitting_def["energy"].r_hessian = True
self._enable_hessian = True

def atomic_output_def(self) -> FittingOutputDef:
if self._enable_hessian:
return self.hess_fitting_def
return super().atomic_output_def()
92 changes: 92 additions & 0 deletions deepmd/dpmodel/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ModelOutputDef,
OutputVariableDef,
get_deriv_name,
get_hessian_name,
get_reduce_name,
)

Expand Down Expand Up @@ -81,6 +82,7 @@ def communicate_extended_output(

"""
xp = array_api_compat.get_namespace(mapping)
mapping_ = mapping
new_ret = {}
for kk in model_output_def.keys_outp():
vv = model_ret[kk]
Expand Down Expand Up @@ -116,6 +118,96 @@ def communicate_extended_output(
else:
# name holders
new_ret[kk_derv_r] = None
if vdef.r_hessian:
kk_hess = get_hessian_name(kk)
if model_ret[kk_hess] is not None:
# jax only
if array_api_compat.is_jax_array(force):
Comment thread
njzjz marked this conversation as resolved.
Outdated
from deepmd.jax.common import (
scatter_sum,
)
from deepmd.jax.env import (
jnp,
)

# [nf, *def, nall, 3, nall, 3]
hess_ = model_ret[kk_hess]
def_ndim = len(vdef.shape)
# [nf, nall1, nall2, *def, 3(1), 3(2)]
hess_1 = jnp.transpose(
hess_,
(
0,
def_ndim + 1,
def_ndim + 3,
*range(1, def_ndim + 1),
def_ndim + 2,
def_ndim + 4,
),
)
nall = hess_1.shape[1]
# (1) -> [nf, nloc1, nall2, *def, 3(1), 3(2)]
hessian1 = jnp.zeros(
[*vldims, nall, *vdef.shape, 3, 3], dtype=vv.dtype
)
mapping_hess = xp.reshape(
mapping_, (mldims + [1] * (len(vdef.shape) + 3))
)
mapping_hess = xp.tile(
mapping_hess,
[1] * len(mldims) + [nall, *vdef.shape, 3, 3],
)
hessian1 = scatter_sum(
hessian1,
1,
mapping_hess,
hess_1,
)
# [nf, nall2, nloc1, *def, 3(1), 3(2)]
hessian1 = jnp.transpose(
hessian1,
(0, 2, 1, *range(3, def_ndim + 5)),
)
nloc = hessian1.shape[2]
# (2) -> [nf, nloc2, nloc1, *def, 3(1), 3(2)]
hessian = jnp.zeros(
[*vldims, nloc, *vdef.shape, 3, 3], dtype=vv.dtype
)
mapping_hess = xp.reshape(
mapping_, (mldims + [1] * (len(vdef.shape) + 3))
)
mapping_hess = xp.tile(
mapping_hess,
[1] * len(mldims) + [nloc, *vdef.shape, 3, 3],
)
hessian = scatter_sum(
hessian,
1,
mapping_hess,
hessian1,
)
# -> [nf, *def, nloc1, 3(1), nloc2, 3(2)]
hessian = jnp.transpose(
hessian,
(
0,
*range(3, def_ndim + 3),
2,
def_ndim + 3,
1,
def_ndim + 4,
),
)
# -> [nf, *def, nloc1 * 3, nloc2 * 3]
hessian = jnp.reshape(
hessian,
(hessian.shape[0], *vdef.shape, nloc * 3, nloc * 3),
)
else:
raise NotImplementedError("Only JAX arrays are supported.")
new_ret[kk_hess] = hessian
else:
new_ret[kk_hess] = None
Comment thread
njzjz marked this conversation as resolved.
if vdef.c_differentiable:
assert vdef.r_differentiable
if model_ret[kk_derv_c] is not None:
Expand Down
13 changes: 13 additions & 0 deletions deepmd/jax/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)
from deepmd.dpmodel.output_def import (
get_deriv_name,
get_hessian_name,
get_reduce_name,
)
from deepmd.jax.env import (
Expand Down Expand Up @@ -87,6 +88,18 @@ def eval_output(
)

model_predict[kk_derv_r] = extended_force
if vdef.r_hessian:
# [nf, *def, nall, 3, nall, 3]
hessian = jax.vmap(jax.hessian(eval_output, argnums=0))(
extended_coord,
extended_atype,
nlist,
mapping,
fparam,
aparam,
)
kk_hessian = get_hessian_name(kk)
model_predict[kk_hessian] = hessian
if vdef.c_differentiable:
assert vdef.r_differentiable
# avr: [nf, *def, nall, 3, 3]
Expand Down
106 changes: 106 additions & 0 deletions source/tests/jax/test_dp_hessian_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.jax.common import (
to_jax_array,
)
from deepmd.jax.descriptor.se_e2_a import (
DescrptSeA,
)
from deepmd.jax.env import (
jnp,
)
from deepmd.jax.fitting.fitting import (
EnergyFittingNet,
)
from deepmd.jax.model.ener_model import (
EnergyModel,
)

dtype = jnp.float64


class TestCaseSingleFrameWithoutNlist:
def setUp(self) -> None:
# nloc == 3, nall == 4
self.nloc = 3
self.nf, self.nt = 1, 2
self.coord = np.array(
[
[0, 0, 0],
[0, 1, 0],
[0, 0, 1],
],
dtype=np.float64,
).reshape([1, self.nloc * 3])
self.atype = np.array([0, 0, 1], dtype=int).reshape([1, self.nloc])
self.cell = 2.0 * np.eye(3).reshape([1, 9])
# sel = [5, 2]
self.sel = [16, 8]
self.sel_mix = [24]
self.natoms = [3, 3, 2, 1]
self.rcut = 2.2
self.rcut_smth = 0.4
self.atol = 1e-12


class TestEnergyHessianModel(unittest.TestCase, TestCaseSingleFrameWithoutNlist):
def setUp(self):
TestCaseSingleFrameWithoutNlist.setUp(self)

def test_self_consistency(self):
ds = DescrptSeA(
self.rcut,
self.rcut_smth,
self.sel,
)
ft = EnergyFittingNet(
self.nt,
ds.get_dim_out(),
mixed_types=ds.mixed_types(),
)
type_map = ["foo", "bar"]
md0 = EnergyModel(ds, ft, type_map=type_map)
md1 = EnergyModel.deserialize(md0.serialize())
md0.enable_hessian()
md1.enable_hessian()
args = [to_jax_array(ii) for ii in [self.coord, self.atype, self.cell]]
ret0 = md0.call(*args)
ret1 = md1.call(*args)
np.testing.assert_allclose(
to_numpy_array(ret0["energy"]),
to_numpy_array(ret1["energy"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_redu"]),
to_numpy_array(ret1["energy_redu"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_r"]),
to_numpy_array(ret1["energy_derv_r"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_c_redu"]),
to_numpy_array(ret1["energy_derv_c_redu"]),
atol=self.atol,
)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_r_derv_r"]),
to_numpy_array(ret1["energy_derv_r_derv_r"]),
atol=self.atol,
)
ret0 = md0.call(*args, do_atomic_virial=True)
ret1 = md1.call(*args, do_atomic_virial=True)
np.testing.assert_allclose(
to_numpy_array(ret0["energy_derv_c"]),
to_numpy_array(ret1["energy_derv_c"]),
atol=self.atol,
)
Loading