Skip to content

Commit 9421022

Browse files
committed
fix(pt): resolve issues with dipole charge modifier
- Use pickle instead of json for modifier serialization to handle np.ndarray - Fix device placement in EwaldReal class by explicitly setting device and dtype - Add proper error handling in test case to ensure torch default dtype is restored
1 parent 99287fb commit 9421022

3 files changed

Lines changed: 26 additions & 23 deletions

File tree

deepmd/pt/entrypoints/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,8 @@ def freeze(
404404
extra_files = {"modifier_data": ""}
405405
dm = tester.modifier
406406
if dm is not None:
407+
# dict from dm.serialize() includes np.ndarray
408+
# use pickle rather than json
407409
bytes_data = pickle.dumps(dm.serialize())
408410
extra_files = {"modifier_data": bytes_data}
409411
torch.jit.save(

deepmd/pt/modifier/dipole_charge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(
8585
rspace=False,
8686
kappa=ewald_beta,
8787
spacing=ewald_h,
88-
).to(env.GLOBAL_PT_FLOAT_PRECISION)
88+
).to(device=env.DEVICE, dtype=env.GLOBAL_PT_FLOAT_PRECISION)
8989
self.er = torch.jit.script(er)
9090
self.er.eval()
9191
self.placeholder_pairs = torch.ones((1, 2), device=env.DEVICE, dtype=torch.long)

source/tests/pt/modifier/test_dipole_charge.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -160,30 +160,31 @@ def test_consistency(self):
160160
dtype = torch.get_default_dtype()
161161
torch.set_default_dtype(torch.float64)
162162

163-
coord, box, atype = ref_data()
163+
try:
164+
coord, box, atype = ref_data()
164165

165-
pt_data = self.dm_pt.eval_np(
166-
coord=coord,
167-
atype=atype,
168-
box=box,
169-
)
170-
tf_data = self.dm_tf.eval(
171-
coord=coord,
172-
box=box,
173-
atype=atype.reshape(-1),
174-
)
175-
tol = 1e-6
176-
output_names = ["energy", "force", "virial"]
177-
for ii, name in enumerate(output_names):
178-
np.testing.assert_allclose(
179-
pt_data[ii].reshape(-1),
180-
tf_data[ii].reshape(-1),
181-
atol=tol,
182-
rtol=tol,
183-
err_msg=f"Mismatch in {name}",
166+
pt_data = self.dm_pt.eval_np(
167+
coord=coord,
168+
atype=atype,
169+
box=box,
184170
)
185-
186-
torch.set_default_dtype(dtype)
171+
tf_data = self.dm_tf.eval(
172+
coord=coord,
173+
box=box,
174+
atype=atype.reshape(-1),
175+
)
176+
tol = 1e-6
177+
output_names = ["energy", "force", "virial"]
178+
for ii, name in enumerate(output_names):
179+
np.testing.assert_allclose(
180+
pt_data[ii].reshape(-1),
181+
tf_data[ii].reshape(-1),
182+
atol=tol,
183+
rtol=tol,
184+
err_msg=f"Mismatch in {name}",
185+
)
186+
finally:
187+
torch.set_default_dtype(dtype)
187188

188189
def test_serialize(self):
189190
"""Test the serialize method of DipoleChargeModifier."""

0 commit comments

Comments
 (0)