Skip to content

Commit ea76ad6

Browse files
committed
Update virial and add UT.
1 parent 3bcee0e commit ea76ad6

File tree

3 files changed

+101
-2
lines changed

3 files changed

+101
-2
lines changed

deepmd/tf/descriptor/se_a.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,23 @@ def build(
670670
atype_t = tf.concat([[self.ntypes], tf.reshape(self.atype, [-1])], axis=0)
671671
self.nei_type_vec = tf.nn.embedding_lookup(atype_t, nlist_t)
672672

673+
if self.spin is not None:
674+
judge = tf.equal(natoms[0], natoms[1])
675+
self.diff_coord = tf.cond(
676+
judge,
677+
lambda: self.natoms_match(coord, natoms),
678+
lambda: self.natoms_not_match(coord, natoms, atype),
679+
)
680+
diff_coord_reshape = tf.reshape(self.diff_coord, [-1, natoms[1], 3])
681+
nlist_reshaped = tf.reshape(self.nlist, [-1, natoms[0], self.nnei])
682+
rj_gathered = tf.gather(
683+
diff_coord_reshape, tf.maximum(nlist_reshaped, 0), axis=1, batch_dims=1
684+
)
685+
ri_gathered = diff_coord_reshape[:, : natoms[0], tf.newaxis, :]
686+
rij_reshaped = tf.reshape(self.rij, [-1, natoms[0], self.nnei, 3])
687+
rij_update = rij_reshaped - rj_gathered + ri_gathered
688+
self.rij = tf.reshape(rij_update, tf.shape(self.rij))
689+
673690
# only used when tensorboard was set as true
674691
tf.summary.histogram("descrpt", self.descrpt)
675692
tf.summary.histogram("rij", self.rij)
@@ -1374,6 +1391,62 @@ def init_variables(
13741391
)
13751392
)
13761393

1394+
def natoms_match(self, coord, natoms):
1395+
natoms_index = tf.concat([[0], tf.cumsum(natoms[2:])], axis=0)
1396+
diff_coord_loc = []
1397+
for i in range(self.ntypes):
1398+
if i + self.ntypes_spin >= self.ntypes:
1399+
diff_coord_loc.append(
1400+
tf.slice(
1401+
coord,
1402+
[0, natoms_index[i] * 3],
1403+
[-1, natoms[2 + i] * 3],
1404+
)
1405+
- tf.slice(
1406+
coord,
1407+
[0, natoms_index[i - len(self.spin.use_spin)] * 3],
1408+
[-1, natoms[2 + i - len(self.spin.use_spin)] * 3],
1409+
)
1410+
)
1411+
else:
1412+
diff_coord_loc.append(
1413+
tf.zeros([tf.shape(coord)[0], natoms[2 + i] * 3], dtype=coord.dtype)
1414+
)
1415+
diff_coord_loc = tf.concat(diff_coord_loc, axis=1)
1416+
return diff_coord_loc
1417+
1418+
def natoms_not_match(self, coord, natoms, atype):
1419+
diff_coord_loc = self.natoms_match(coord, natoms)
1420+
diff_coord_ghost = []
1421+
aatype = atype[0, :]
1422+
ghost_atype = aatype[natoms[0] :]
1423+
_, _, ghost_natoms = tf.unique_with_counts(ghost_atype)
1424+
ghost_natoms_index = tf.concat([[0], tf.cumsum(ghost_natoms)], axis=0)
1425+
ghost_natoms_index += natoms[0]
1426+
for i in range(self.ntypes):
1427+
if i + self.ntypes_spin >= self.ntypes:
1428+
diff_coord_ghost.append(
1429+
tf.slice(
1430+
coord,
1431+
[0, ghost_natoms_index[i] * 3],
1432+
[-1, ghost_natoms[i] * 3],
1433+
)
1434+
- tf.slice(
1435+
coord,
1436+
[0, ghost_natoms_index[i - len(self.spin.use_spin)] * 3],
1437+
[-1, ghost_natoms[i - len(self.spin.use_spin)] * 3],
1438+
)
1439+
)
1440+
else:
1441+
diff_coord_ghost.append(
1442+
tf.zeros(
1443+
[tf.shape(coord)[0], ghost_natoms[i] * 3], dtype=coord.dtype
1444+
)
1445+
)
1446+
diff_coord_ghost = tf.concat(diff_coord_ghost, axis=1)
1447+
diff_coord = tf.concat([diff_coord_loc, diff_coord_ghost], axis=1)
1448+
return diff_coord
1449+
13771450
@property
13781451
def explicit_ntypes(self) -> bool:
13791452
"""Explicit ntypes with type embedding."""

source/api_cc/src/DeepSpinTF.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -809,11 +809,11 @@ void DeepSpinTF::compute(ENERGYVTYPE& dener,
809809
datom_energy_.resize(static_cast<size_t>(nframes) * nall);
810810
datom_virial_.resize(static_cast<size_t>(nframes) * nall * 9);
811811
for (int ii = 0; ii < nall; ++ii) {
812+
int new_idx = new_idx_map[ii];
812813
for (int dd = 0; dd < 3; ++dd) {
813-
int new_idx = new_idx_map[ii];
814814
dforce_[3 * ii + dd] = dforce_tmp[3 * new_idx + dd];
815815
datom_energy_[ii] = datom_energy_tmp[new_idx];
816-
datom_virial_[ii] = datom_virial_tmp[new_idx];
816+
817817
if (datype_[ii] < ntypes_spin && ii < nloc) {
818818
dforce_mag_[3 * ii + dd] = dforce_tmp[3 * (new_idx + nloc) + dd];
819819
} else if (datype_[ii] < ntypes_spin) {
@@ -822,6 +822,9 @@ void DeepSpinTF::compute(ENERGYVTYPE& dener,
822822
dforce_mag_[3 * ii + dd] = 0.0;
823823
}
824824
}
825+
for (int dd = 0; dd < 9; ++dd) {
826+
datom_virial_[ii * 9 + dd] = datom_virial_tmp[new_idx * 9 + dd];
827+
}
825828
}
826829
}
827830

source/tests/tf/test_model_spin.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def test_model_spin(self) -> None:
126126
)
127127

128128
out_ener = np.reshape(out_ener, [-1])
129+
out_virial = np.reshape(out_virial, [-1])
129130
natoms_real = np.sum(
130131
test_data["natoms_vec"][2 : 2 + len(spin_param["use_spin"])]
131132
)
@@ -425,14 +426,36 @@ def test_model_spin(self) -> None:
425426
-0.0007015535942944564,
426427
0.004459188855221506,
427428
]
429+
refv = [
430+
0.33691325723275595,
431+
0.024301747372056412,
432+
-0.06880806009046331,
433+
0.026792188153995887,
434+
0.3007953744219118,
435+
-0.051612531097108075,
436+
-0.07274496651648972,
437+
-0.05156414798680478,
438+
0.37692901508963417,
439+
0.3205610686355494,
440+
0.013102936385366228,
441+
-0.04419007538301404,
442+
0.014186144311082909,
443+
0.31565216176483,
444+
-0.058829665227551474,
445+
-0.04759429793837308,
446+
-0.05932221615318792,
447+
0.39040431257661773,
448+
]
428449
refe = np.reshape(refe, [-1])
429450
refr = np.reshape(refr, [-1])
430451
refm = np.reshape(refm, [-1])
452+
refv = np.reshape(refv, [-1])
431453

432454
places = 10
433455
np.testing.assert_almost_equal(out_ener, refe, places)
434456
np.testing.assert_almost_equal(force_real, refr, places)
435457
np.testing.assert_almost_equal(force_mag, refm, places)
458+
np.testing.assert_almost_equal(out_virial, refv, places)
436459

437460

438461
if __name__ == "__main__":

0 commit comments

Comments
 (0)