Skip to content

Commit 409fcad

Browse files
committed
refactor(pt): refactor DipoleChargeModifier to inherit from DPModifier
- Refactor DipoleChargeModifier to inherit from DPModifier instead of BaseModifier
- Simplify model initialization and batch processing logic
- Update torch-admp dependency version to >=1.1.4
- Update tests to match new API and add dipole consistency test
1 parent 7fd4801 commit 409fcad

3 files changed

Lines changed: 103 additions & 98 deletions

File tree

deepmd/pt/modifier/dipole_charge.py

Lines changed: 45 additions & 91 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2-
import os
32

43
import numpy as np
54
import torch
@@ -16,20 +15,20 @@
1615
from deepmd.pt.modifier.base_modifier import (
1716
BaseModifier,
1817
)
18+
from deepmd.pt.modifier.dp_modifier import (
19+
DPModifier,
20+
)
1921
from deepmd.pt.utils import (
2022
env,
2123
)
22-
from deepmd.pt.utils.serialization import (
23-
serialize_from_file,
24-
)
2524
from deepmd.pt.utils.utils import (
2625
to_numpy_array,
2726
to_torch_tensor,
2827
)
2928

3029

3130
@BaseModifier.register("dipole_charge")
32-
class DipoleChargeModifier(BaseModifier):
31+
class DipoleChargeModifier(DPModifier):
3332
"""Parameters
3433
----------
3534
model_name
@@ -46,36 +45,22 @@ class DipoleChargeModifier(BaseModifier):
4645

4746
def __init__(
4847
self,
49-
model_name: str | None,
48+
dp_model: DipoleModel | None,
5049
model_charge_map: list[float],
5150
sys_charge_map: list[float],
5251
ewald_h: float = 1.0,
5352
ewald_beta: float = 1.0,
54-
ewald_batch_size: int = 5,
55-
dp_batch_size: int | None = None,
56-
model: DipoleModel | None = None,
53+
dp_model_file_name: str | None = None,
5754
use_cache: bool = True,
5855
) -> None:
5956
"""Constructor."""
60-
super().__init__(use_cache=use_cache)
6157
self.modifier_type = "dipole_charge"
58+
super().__init__(
59+
dp_model=dp_model,
60+
dp_model_file_name=dp_model_file_name,
61+
use_cache=use_cache,
62+
)
6263

63-
if model_name is None and model is None:
64-
raise AttributeError("`model_name` or `model` should be specified.")
65-
if model_name is not None and model is not None:
66-
raise AttributeError(
67-
"`model_name` and `model` cannot be used simultaneously."
68-
)
69-
70-
if model is not None:
71-
self._model = model.to(env.DEVICE)
72-
if model_name is not None:
73-
data = serialize_from_file(model_name)
74-
self._model = DipoleModel.deserialize(data["model"]).to(env.DEVICE)
75-
self._model.eval()
76-
77-
# use jit model for inference
78-
self.model = torch.jit.script(self._model)
7964
self.rcut = self.model.get_rcut()
8065
self.type_map = self.model.get_type_map()
8166
sel_type = self.model.get_sel_type()
@@ -95,23 +80,22 @@ def __init__(
9580
# init ewald recp
9681
self.ewald_h = ewald_h
9782
self.ewald_beta = ewald_beta
98-
self.er = CoulombForceModule(
83+
er = CoulombForceModule(
9984
rcut=self.rcut,
10085
rspace=False,
10186
kappa=ewald_beta,
10287
spacing=ewald_h,
103-
)
88+
).to(env.GLOBAL_PT_FLOAT_PRECISION)
89+
self.er = torch.jit.script(er)
90+
self.er.eval()
10491
self.placeholder_pairs = torch.ones((1, 2), device=env.DEVICE, dtype=torch.long)
105-
self.placeholder_ds = torch.ones((1), device=env.DEVICE, dtype=torch.float64)
92+
self.placeholder_ds = torch.ones(
93+
(1), device=env.DEVICE, dtype=env.GLOBAL_PT_FLOAT_PRECISION
94+
)
10695
self.placeholder_buffer_scales = torch.zeros(
107-
(1), device=env.DEVICE, dtype=torch.float64
96+
(1), device=env.DEVICE, dtype=env.GLOBAL_PT_FLOAT_PRECISION
10897
)
10998

110-
self.ewald_batch_size = ewald_batch_size
111-
if dp_batch_size is None:
112-
dp_batch_size = int(os.environ.get("DP_INFER_BATCH_SIZE", 1))
113-
self.dp_batch_size = dp_batch_size
114-
11599
def serialize(self) -> dict:
116100
"""Serialize the modifier.
117101
@@ -120,16 +104,13 @@ def serialize(self) -> dict:
120104
dict
121105
The serialized data
122106
"""
123-
dd = BaseModifier.serialize(self)
107+
dd = super().serialize()
124108
dd.update(
125109
{
126-
"model": self._model.serialize(),
127110
"model_charge_map": self._model_charge_map,
128111
"sys_charge_map": self._sys_charge_map,
129112
"ewald_h": self.ewald_h,
130113
"ewald_beta": self.ewald_beta,
131-
"ewald_batch_size": self.ewald_batch_size,
132-
"dp_batch_size": self.dp_batch_size,
133114
}
134115
)
135116
return dd
@@ -140,9 +121,9 @@ def deserialize(cls, data: dict) -> "DipoleChargeModifier":
140121
data.pop("@class", None)
141122
data.pop("type", None)
142123
data.pop("@version", None)
143-
model_obj = DipoleModel.deserialize(data.pop("model"))
144-
data["model"] = model_obj
145-
data["model_name"] = None
124+
model_obj = DipoleModel.deserialize(data.pop("dp_model"))
125+
data["dp_model"] = model_obj
126+
data["dp_model_file_name"] = None
146127
return cls(**data)
147128

148129
def forward(
@@ -213,30 +194,23 @@ def forward(
213194
)
214195

215196
# add Ewald reciprocal correction
216-
tot_e: list[torch.Tensor] = []
217-
chunk_coord = torch.split(
218-
extended_coord.reshape(nframes, -1, 3), self.dp_batch_size, dim=0
197+
placeholder_pairs = torch.tile(
198+
self.placeholder_pairs.unsqueeze(0), (nframes, 1, 1)
219199
)
220-
chunk_box = torch.split(
221-
input_box.reshape(nframes, 3, 3), self.dp_batch_size, dim=0
200+
placeholder_ds = torch.tile(self.placeholder_ds.unsqueeze(0), (nframes, 1))
201+
placeholder_buffer_scales = torch.tile(
202+
self.placeholder_buffer_scales.unsqueeze(0), (nframes, 1)
222203
)
223-
chunk_charge = torch.split(
224-
extended_charge.reshape(nframes, -1), self.dp_batch_size, dim=0
204+
self.er(
205+
extended_coord.reshape(nframes, 2 * natoms, 3),
206+
input_box.reshape(nframes, 3, 3),
207+
placeholder_pairs,
208+
placeholder_ds,
209+
placeholder_buffer_scales,
210+
{"charge": extended_charge.reshape(nframes, 2 * natoms)},
225211
)
226-
for _coord, _box, _charge in zip(
227-
chunk_coord, chunk_box, chunk_charge, strict=True
228-
):
229-
self.er(
230-
_coord,
231-
_box,
232-
self.placeholder_pairs,
233-
self.placeholder_ds,
234-
self.placeholder_buffer_scales,
235-
{"charge": _charge},
236-
)
237-
tot_e.append(self.er.reciprocal_energy.unsqueeze(0))
238212
# nframe,
239-
tot_e = torch.concat(tot_e, dim=0)
213+
tot_e = self.er.reciprocal_energy
240214
# nframe, nat * 3
241215
tot_f = -calc_grads(tot_e, input_coord)
242216
# nframe, nat, 3
@@ -346,37 +320,17 @@ def extend_system_coord(
346320
nframes = coord.shape[0]
347321
natoms = coord.shape[1] // 3
348322

349-
all_dipole: list[torch.Tensor] = []
350-
chunk_coord = torch.split(coord, self.dp_batch_size, dim=0)
351-
chunk_atype = torch.split(atype, self.dp_batch_size, dim=0)
352-
chunk_box = torch.split(box, self.dp_batch_size, dim=0)
353-
# use placeholder to make the jit happy for fparam/aparam is None
354-
chunk_fparam = (
355-
torch.split(fparam, self.dp_batch_size, dim=0)
356-
if fparam is not None
357-
else chunk_atype
358-
)
359-
chunk_aparam = (
360-
torch.split(aparam, self.dp_batch_size, dim=0)
361-
if aparam is not None
362-
else chunk_atype
323+
model_pred = self.model(
324+
coord=coord,
325+
atype=atype,
326+
box=box,
327+
do_atomic_virial=False,
328+
fparam=fparam if fparam is not None else None,
329+
aparam=aparam if aparam is not None else None,
363330
)
364-
for _coord, _atype, _box, _fparam, _aparam in zip(
365-
chunk_coord, chunk_atype, chunk_box, chunk_fparam, chunk_aparam, strict=True
366-
):
367-
dipole_batch = self.model(
368-
coord=_coord,
369-
atype=_atype,
370-
box=_box,
371-
do_atomic_virial=False,
372-
fparam=_fparam if fparam is not None else None,
373-
aparam=_aparam if aparam is not None else None,
374-
)
375-
# Extract dipole from the output dictionary
376-
all_dipole.append(dipole_batch["dipole"])
377331

378-
# nframe x natoms x 3
379-
dipole = torch.cat(all_dipole, dim=0)
332+
# nframe, natoms, 3
333+
dipole = model_pred["dipole"]
380334
if dipole.shape[0] != nframes:
381335
raise RuntimeError(
382336
f"Dipole shape mismatch: expected {nframes} frames, got {dipole.shape[0]}"

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -168,11 +168,11 @@ pin_pytorch_cpu = [
168168
# macos x86 has been deprecated
169169
"torch>=2.8,<2.10; platform_machine!='x86_64' or platform_system != 'Darwin'",
170170
"torch; platform_machine=='x86_64' and platform_system == 'Darwin'",
171-
"torch-admp==1.1.3",
171+
"torch-admp>=1.1.4",
172172
]
173173
pin_pytorch_gpu = [
174174
"torch>=2.7,<2.10",
175-
"torch-admp==1.1.3",
175+
"torch-admp>=1.1.4",
176176
]
177177
pin_jax = [
178178
"jax==0.5.0;python_version>='3.10'",

source/tests/pt/modifier/test_dipole_charge.py

Lines changed: 56 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,12 @@
1313
from deepmd.entrypoints.convert_backend import (
1414
convert_backend,
1515
)
16+
from deepmd.env import (
17+
GLOBAL_NP_FLOAT_PRECISION,
18+
)
19+
from deepmd.infer import (
20+
DeepEval,
21+
)
1622
from deepmd.pt.entrypoints.main import (
1723
freeze,
1824
get_trainer,
@@ -41,8 +47,8 @@ def ref_data():
4147
rng = np.random.default_rng(GLOBAL_SEED)
4248
selected_id = rng.integers(nframe)
4349

44-
coord = all_coord[selected_id].reshape(1, -1)
45-
box = all_box[selected_id].reshape(1, -1)
50+
coord = all_coord[selected_id].reshape(1, -1).astype(GLOBAL_NP_FLOAT_PRECISION)
51+
box = all_box[selected_id].reshape(1, -1).astype(GLOBAL_NP_FLOAT_PRECISION)
4652
atype = np.loadtxt(
4753
str(Path(__file__).parent / "water/data/data_0/type.raw"),
4854
dtype=int,
@@ -68,6 +74,14 @@ def setUp(self) -> None:
6874
"rcut": 4.00,
6975
"neuron": [6, 12, 24],
7076
}
77+
self.modifier_dict = {
78+
"type": "dipole_charge",
79+
"model_name": "dw_model.pth",
80+
"model_charge_map": self.model_charge_map,
81+
"sys_charge_map": self.sys_charge_map,
82+
"ewald_beta": self.ewald_beta,
83+
"ewald_h": self.ewald_h,
84+
}
7185

7286
# Train DW model
7387
input_json = str(Path(__file__).parent / "water_tensor/se_e2_a.json")
@@ -95,11 +109,12 @@ def setUp(self) -> None:
95109
convert_backend(INPUT="dw_model.pth", OUTPUT="dw_model.pb")
96110

97111
self.dm_pt = PTDipoleChargeModifier(
98-
"dw_model.pth",
112+
None,
99113
self.model_charge_map,
100114
self.sys_charge_map,
101115
self.ewald_h,
102116
self.ewald_beta,
117+
"dw_model.pth",
103118
)
104119
self.dm_tf = TFDipoleChargeModifier(
105120
"dw_model.pb",
@@ -112,7 +127,39 @@ def setUp(self) -> None:
112127
def test_jit(self):
113128
torch.jit.script(self.dm_pt)
114129

130+
def test_dipole_consistency(self):
131+
coord, box, atype = ref_data()
132+
tf_model = DeepEval("dw_model.pb")
133+
tf_data = tf_model.eval(
134+
coords=coord,
135+
cells=box,
136+
atom_types=atype.reshape(-1),
137+
)
138+
139+
nframes = 1
140+
input_box = to_torch_tensor(box).reshape(nframes, 9)
141+
input_coord = to_torch_tensor(coord).reshape(nframes, -1)
142+
input_atype = to_torch_tensor(atype).to(torch.long)
143+
_extended_coord, _extended_charge, atomic_dipole = self.dm_pt.extend_system(
144+
input_coord,
145+
input_atype,
146+
input_box,
147+
None,
148+
None,
149+
)
150+
151+
np.testing.assert_allclose(
152+
tf_data.reshape(-1, 3),
153+
to_numpy_array(atomic_dipole).reshape(-1, 3),
154+
)
155+
156+
@unittest.skipIf(
157+
env.GLOBAL_PT_FLOAT_PRECISION != torch.float64, "run only for double precision"
158+
)
115159
def test_consistency(self):
160+
dtype = torch.get_default_dtype()
161+
torch.set_default_dtype(torch.float64)
162+
116163
coord, box, atype = ref_data()
117164

118165
pt_data = self.dm_pt.eval_np(
@@ -125,16 +172,19 @@ def test_consistency(self):
125172
box=box,
126173
atype=atype.reshape(-1),
127174
)
175+
tol = 1e-6
128176
output_names = ["energy", "force", "virial"]
129177
for ii, name in enumerate(output_names):
130178
np.testing.assert_allclose(
131179
pt_data[ii].reshape(-1),
132180
tf_data[ii].reshape(-1),
133-
atol=1e-6,
134-
rtol=1e-6,
181+
atol=tol,
182+
rtol=tol,
135183
err_msg=f"Mismatch in {name}",
136184
)
137185

186+
torch.set_default_dtype(dtype)
187+
138188
def test_serialize(self):
139189
"""Test the serialize method of DipoleChargeModifier."""
140190
coord, box, atype = ref_data()
@@ -206,6 +256,7 @@ def test_train(self):
206256
]
207257
config["training"]["numb_steps"] = 1
208258

259+
config["model"]["modifier"] = self.modifier_dict
209260
trainer = get_trainer(config)
210261
trainer.run()
211262
# Verify model checkpoint was created

0 commit comments

Comments (0)