Skip to content

Commit bf3e483

Browse files
Fix: dptest label mismatch (#5082)
Currently, when running dptest on atomic polar. The order of atoms for `coord` and `label` are misaligned. This is a result of type_sel = None for coord, but not None for label. I tried to set it to None for label as well, but it broke TF compatibility. So the easiest hack I can think of is to add a check while loading data. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Tests** * Enhanced output handling in test suite to improve per-atom count tracking in selection outputs. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent afb97f6 commit bf3e483

File tree

6 files changed

+528
-11
lines changed

6 files changed

+528
-11
lines changed

deepmd/entrypoints/test.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,7 @@ def test_polar(
11341134
must=True,
11351135
high_prec=False,
11361136
type_sel=dp.get_sel_type(),
1137+
output_natoms_for_type_sel=True,
11371138
)
11381139

11391140
test_data = data.get_test()
@@ -1155,7 +1156,12 @@ def test_polar(
11551156
polar = polar.reshape((polar.shape[0], -1, 9))[:, sel_mask, :].reshape(
11561157
(polar.shape[0], -1)
11571158
)
1158-
rmse_f = rmse(polar - test_data["atom_polarizability"][:numb_test])
1159+
label_polar = (
1160+
test_data["atom_polarizability"][:numb_test]
1161+
.reshape((numb_test, -1, 9))[:, sel_mask, :]
1162+
.reshape((numb_test, -1))
1163+
)
1164+
rmse_f = rmse(polar - label_polar)
11591165

11601166
log.info(f"# number of test data : {numb_test:d} ")
11611167
log.info(f"Polarizability RMSE : {rmse_f:e}")
@@ -1183,10 +1189,7 @@ def test_polar(
11831189
else:
11841190
pe = np.concatenate(
11851191
(
1186-
np.reshape(
1187-
test_data["atom_polarizability"][:numb_test],
1188-
[-1, 9 * sel_natoms],
1189-
),
1192+
np.reshape(label_polar, [-1, 9 * sel_natoms]),
11901193
np.reshape(polar, [-1, 9 * sel_natoms]),
11911194
),
11921195
axis=1,
@@ -1275,7 +1278,9 @@ def test_dipole(
12751278
must=True,
12761279
high_prec=False,
12771280
type_sel=dp.get_sel_type(),
1281+
output_natoms_for_type_sel=True,
12781282
)
1283+
12791284
test_data = data.get_test()
12801285
dipole, numb_test, atype = run_test(dp, test_data, numb_test, data)
12811286

@@ -1295,7 +1300,12 @@ def test_dipole(
12951300
dipole = dipole.reshape((dipole.shape[0], -1, 3))[:, sel_mask, :].reshape(
12961301
(dipole.shape[0], -1)
12971302
)
1298-
rmse_f = rmse(dipole - test_data["atom_dipole"][:numb_test])
1303+
label_dipole = (
1304+
test_data["atom_dipole"][:numb_test]
1305+
.reshape((numb_test, -1, 3))[:, sel_mask, :]
1306+
.reshape((numb_test, -1))
1307+
)
1308+
rmse_f = rmse(dipole - label_dipole)
12991309

13001310
log.info(f"# number of test data : {numb_test:d}")
13011311
log.info(f"Dipole RMSE : {rmse_f:e}")
@@ -1318,9 +1328,7 @@ def test_dipole(
13181328
else:
13191329
pe = np.concatenate(
13201330
(
1321-
np.reshape(
1322-
test_data["atom_dipole"][:numb_test], [-1, 3 * sel_natoms]
1323-
),
1331+
np.reshape(label_dipole, [-1, 3 * sel_natoms]),
13241332
np.reshape(dipole, [-1, 3 * sel_natoms]),
13251333
),
13261334
axis=1,

deepmd/utils/data.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,9 @@ def add(
188188
output_natoms_for_type_sel : bool, optional
189189
if True and type_sel is True, the atomic dimension will be natoms instead of nsel
190190
"""
191+
# normalize key: "atomic_" prefix -> "atom_", same convention as _load_set output
192+
if key.startswith("atomic_"):
193+
key = "atom_" + key[7:]
191194
self.data_dict[key] = {
192195
"ndof": ndof,
193196
"atomic": atomic,
@@ -762,6 +765,15 @@ def _load_set(self, set_name: DPPath) -> dict[str, Any]:
762765
data = {kk.replace("atomic", "atom"): vv for kk, vv in data.items()}
763766
return data
764767

768+
def _get_data_path(self, set_name: "DPPath", key: str) -> "DPPath":
769+
"""Return the path for a data file, trying both atom_ and atomic_ naming."""
770+
path = set_name / (key + ".npy")
771+
if not path.is_file() and key.startswith("atom_"):
772+
alt = set_name / ("atomic_" + key[5:] + ".npy")
773+
if alt.is_file():
774+
return alt
775+
return path
776+
765777
def _load_data(
766778
self,
767779
set_name: str,
@@ -800,7 +812,7 @@ def _load_data(
800812
dtype = GLOBAL_ENER_FLOAT_PRECISION
801813
else:
802814
dtype = GLOBAL_NP_FLOAT_PRECISION
803-
path = set_name / (key + ".npy")
815+
path = self._get_data_path(set_name, key)
804816
if path.is_file():
805817
data = path.load_numpy().astype(dtype)
806818
try: # YWolfeee: deal with data shape error
@@ -892,7 +904,7 @@ def _load_single_data(
892904
The total number of frames in this set (to avoid redundant _get_nframes calls)
893905
"""
894906
vv = self.data_dict[key]
895-
path = set_dir / (key + ".npy")
907+
path = self._get_data_path(set_dir, key)
896908

897909
if vv["atomic"]:
898910
natoms = self.natoms

source/tests/pt/model/test_dipole_fitting.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,161 @@ def test_deepdipole_infer(self) -> None:
372372
load_md.eval_full(coords=coord, atom_types=atype, cells=cell, atomic=True)
373373
load_md.eval_full(coords=coord, atom_types=atype, cells=cell, atomic=False)
374374

375+
def test_eval_shuffle_sel_type(self) -> None:
376+
# Build a model where only type-0 atoms contribute (exclude types 1 and 2).
377+
# This tests that eval() returns per-atom results in the correct input atom
378+
# order even when sel_type is a strict subset of all types.
379+
ft_sel = DipoleFittingNet(
380+
self.nt,
381+
self.dd0.dim_out,
382+
embedding_width=self.dd0.get_dim_emb(),
383+
numb_fparam=0,
384+
numb_aparam=0,
385+
mixed_types=self.dd0.mixed_types(),
386+
exclude_types=[1, 2],
387+
seed=GLOBAL_SEED,
388+
).to(env.DEVICE)
389+
model_sel = DipoleModel(self.dd0, ft_sel, self.type_mapping)
390+
jit_md = torch.jit.script(model_sel)
391+
torch.jit.save(jit_md, self.file_path)
392+
load_md = DeepDipole(self.file_path)
393+
394+
atype = to_numpy_array(self.atype) # [0, 0, 0, 1, 1]
395+
coord = to_numpy_array(self.coord.reshape(1, self.natoms, 3))
396+
cell = to_numpy_array(self.cell.reshape(1, 9))
397+
398+
# Reference result with original atom order
399+
ref = load_md.eval(coords=coord, atom_types=atype, cells=cell, atomic=True)
400+
# ref shape: [nframes, natoms, nout]
401+
402+
# Shuffle atoms
403+
idx_perm = np.array([1, 0, 4, 3, 2], dtype=np.intp)
404+
coord_sf = coord.reshape(self.natoms, 3)[idx_perm].reshape(1, -1)
405+
atype_sf = atype[idx_perm]
406+
407+
# Result with shuffled atom order
408+
res_sf = load_md.eval(
409+
coords=coord_sf, atom_types=atype_sf, cells=cell, atomic=True
410+
)
411+
# res_sf shape: [nframes, natoms, nout]
412+
413+
# sel_mask: which atoms in the original order are selected (type 0)
414+
sel_mask = np.isin(atype, load_md.get_sel_type()) # [T,T,T,F,F]
415+
sel_mask_sf = sel_mask[idx_perm] # selected atoms in shuffled order
416+
417+
# Extract selected-atom outputs from each result
418+
ref_sel = ref[:, sel_mask] # [nframes, nsel, nout]
419+
at_sf = res_sf[:, sel_mask_sf] # [nframes, nsel, nout]
420+
421+
# isel_sf: mapping from shuffled-selected positions to original-selected positions
422+
orig_sel_idx = np.where(sel_mask)[0]
423+
shuffled_orig = np.array(idx_perm)[sel_mask_sf]
424+
isel_sf = np.array(
425+
[np.where(orig_sel_idx == x)[0][0] for x in shuffled_orig]
426+
) # [1, 0, 2]
427+
428+
# Recover original selected order from shuffled selected
429+
nat = np.empty_like(at_sf)
430+
nat[:, isel_sf] = at_sf
431+
432+
np.testing.assert_almost_equal(
433+
nat.reshape([-1]), ref_sel.reshape([-1]), decimal=10
434+
)
435+
436+
def test_label_order_via_deepmd_data(self) -> None:
437+
"""Verify that labels loaded via DeepmdData(sort_atoms=False) +
438+
output_natoms_for_type_sel=True align with dp.eval() output.
439+
Uses a sel_type model (exclude_types=[1,2]) with atype=[0,0,0,1,1]
440+
shuffled to [0,0,1,1,0] so selected atoms are non-contiguous.
441+
"""
442+
import shutil
443+
import tempfile
444+
445+
from deepmd.utils.data import (
446+
DeepmdData,
447+
)
448+
449+
ft_sel = DipoleFittingNet(
450+
self.nt,
451+
self.dd0.dim_out,
452+
embedding_width=self.dd0.get_dim_emb(),
453+
numb_fparam=0,
454+
numb_aparam=0,
455+
mixed_types=self.dd0.mixed_types(),
456+
exclude_types=[1, 2],
457+
seed=GLOBAL_SEED,
458+
).to(env.DEVICE)
459+
model_sel = DipoleModel(self.dd0, ft_sel, self.type_mapping)
460+
jit_md = torch.jit.script(model_sel)
461+
torch.jit.save(jit_md, self.file_path)
462+
load_md = DeepDipole(self.file_path)
463+
464+
# Shuffle atoms so selected type-0 atoms are non-contiguous
465+
# atype=[0,0,0,1,1] → shuffled idx → atype=[0,0,1,1,0]
466+
idx_perm = np.array([1, 0, 4, 3, 2], dtype=np.intp)
467+
atype = to_numpy_array(self.atype) # [0,0,0,1,1]
468+
coord = to_numpy_array(self.coord.reshape(1, self.natoms, 3))
469+
cell = to_numpy_array(self.cell.reshape(1, 9))
470+
atype_sf = atype[idx_perm]
471+
coord_sf = coord.reshape(self.natoms, 3)[idx_perm].reshape(1, -1)
472+
473+
sel_mask_sf = np.isin(atype_sf, load_md.get_sel_type()) # type-0 positions
474+
475+
# Reference: model output for shuffled atoms, filter to sel atoms
476+
ref_sf = load_md.eval(
477+
coords=coord_sf, atom_types=atype_sf, cells=cell, atomic=True
478+
) # [1, natoms, nout]
479+
ref_sf_sel = ref_sf[:, sel_mask_sf, :] # [1, nsel, nout]
480+
481+
tmpdir = tempfile.mkdtemp()
482+
try:
483+
set_dir = os.path.join(tmpdir, "set.000")
484+
os.makedirs(set_dir)
485+
np.savetxt(os.path.join(tmpdir, "type.raw"), atype_sf, fmt="%d")
486+
np.save(
487+
os.path.join(set_dir, "coord.npy"),
488+
coord_sf.reshape(1, -1),
489+
)
490+
np.save(
491+
os.path.join(set_dir, "box.npy"),
492+
cell.reshape(1, -1),
493+
)
494+
# Labels: nsel atoms in shuffled atom order (nsel format)
495+
np.save(
496+
os.path.join(set_dir, "atomic_dipole.npy"),
497+
ref_sf_sel.reshape(1, -1),
498+
)
499+
500+
data = DeepmdData(
501+
tmpdir,
502+
set_prefix="set",
503+
shuffle_test=False,
504+
type_map=load_md.get_type_map(),
505+
sort_atoms=False,
506+
)
507+
data.add(
508+
"atomic_dipole",
509+
3,
510+
atomic=True,
511+
must=True,
512+
high_prec=False,
513+
type_sel=load_md.get_sel_type(),
514+
output_natoms_for_type_sel=True,
515+
)
516+
test_data = data.get_test()
517+
518+
# Loaded label shape: [1, natoms*3]. Filter to sel atoms.
519+
label_sel = test_data["atom_dipole"].reshape(1, self.natoms, 3)[
520+
:, sel_mask_sf, :
521+
] # [1, nsel, 3]
522+
523+
# Round-trip: loaded label must match what was written
524+
np.testing.assert_almost_equal(
525+
label_sel.reshape(-1), ref_sf_sel.reshape(-1), decimal=5
526+
)
527+
finally:
528+
shutil.rmtree(tmpdir)
529+
375530
def tearDown(self) -> None:
376531
if os.path.exists(self.file_path):
377532
os.remove(self.file_path)

0 commit comments

Comments
 (0)