Skip to content

Commit 33ff63d

Browse files
committed
Update virial and add UT.
1 parent 9c2d3ed commit 33ff63d

File tree

3 files changed

+101
-2
lines changed

3 files changed

+101
-2
lines changed

deepmd/tf/descriptor/se_a.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,23 @@ def build(
663663
atype_t = tf.concat([[self.ntypes], tf.reshape(self.atype, [-1])], axis=0)
664664
self.nei_type_vec = tf.nn.embedding_lookup(atype_t, nlist_t)
665665

666+
if self.spin is not None:
667+
judge = tf.equal(natoms[0], natoms[1])
668+
self.diff_coord = tf.cond(
669+
judge,
670+
lambda: self.natoms_match(coord, natoms),
671+
lambda: self.natoms_not_match(coord, natoms, atype),
672+
)
673+
diff_coord_reshape = tf.reshape(self.diff_coord, [-1, natoms[1], 3])
674+
nlist_reshaped = tf.reshape(self.nlist, [-1, natoms[0], self.nnei])
675+
rj_gathered = tf.gather(
676+
diff_coord_reshape, tf.maximum(nlist_reshaped, 0), axis=1, batch_dims=1
677+
)
678+
ri_gathered = diff_coord_reshape[:, : natoms[0], tf.newaxis, :]
679+
rij_reshaped = tf.reshape(self.rij, [-1, natoms[0], self.nnei, 3])
680+
rij_update = rij_reshaped - rj_gathered + ri_gathered
681+
self.rij = tf.reshape(rij_update, tf.shape(self.rij))
682+
666683
# only used when tensorboard was set as true
667684
tf.summary.histogram("descrpt", self.descrpt)
668685
tf.summary.histogram("rij", self.rij)
@@ -1355,6 +1372,62 @@ def init_variables(
13551372
)
13561373
)
13571374

1375+
def natoms_match(self, coord, natoms):
1376+
natoms_index = tf.concat([[0], tf.cumsum(natoms[2:])], axis=0)
1377+
diff_coord_loc = []
1378+
for i in range(self.ntypes):
1379+
if i + self.ntypes_spin >= self.ntypes:
1380+
diff_coord_loc.append(
1381+
tf.slice(
1382+
coord,
1383+
[0, natoms_index[i] * 3],
1384+
[-1, natoms[2 + i] * 3],
1385+
)
1386+
- tf.slice(
1387+
coord,
1388+
[0, natoms_index[i - len(self.spin.use_spin)] * 3],
1389+
[-1, natoms[2 + i - len(self.spin.use_spin)] * 3],
1390+
)
1391+
)
1392+
else:
1393+
diff_coord_loc.append(
1394+
tf.zeros([tf.shape(coord)[0], natoms[2 + i] * 3], dtype=coord.dtype)
1395+
)
1396+
diff_coord_loc = tf.concat(diff_coord_loc, axis=1)
1397+
return diff_coord_loc
1398+
1399+
def natoms_not_match(self, coord, natoms, atype):
1400+
diff_coord_loc = self.natoms_match(coord, natoms)
1401+
diff_coord_ghost = []
1402+
aatype = atype[0, :]
1403+
ghost_atype = aatype[natoms[0] :]
1404+
_, _, ghost_natoms = tf.unique_with_counts(ghost_atype)
1405+
ghost_natoms_index = tf.concat([[0], tf.cumsum(ghost_natoms)], axis=0)
1406+
ghost_natoms_index += natoms[0]
1407+
for i in range(self.ntypes):
1408+
if i + self.ntypes_spin >= self.ntypes:
1409+
diff_coord_ghost.append(
1410+
tf.slice(
1411+
coord,
1412+
[0, ghost_natoms_index[i] * 3],
1413+
[-1, ghost_natoms[i] * 3],
1414+
)
1415+
- tf.slice(
1416+
coord,
1417+
[0, ghost_natoms_index[i - len(self.spin.use_spin)] * 3],
1418+
[-1, ghost_natoms[i - len(self.spin.use_spin)] * 3],
1419+
)
1420+
)
1421+
else:
1422+
diff_coord_ghost.append(
1423+
tf.zeros(
1424+
[tf.shape(coord)[0], ghost_natoms[i] * 3], dtype=coord.dtype
1425+
)
1426+
)
1427+
diff_coord_ghost = tf.concat(diff_coord_ghost, axis=1)
1428+
diff_coord = tf.concat([diff_coord_loc, diff_coord_ghost], axis=1)
1429+
return diff_coord
1430+
13581431
@property
13591432
def explicit_ntypes(self) -> bool:
13601433
"""Explicit ntypes with type embedding."""

source/api_cc/src/DeepSpinTF.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -809,11 +809,11 @@ void DeepSpinTF::compute(ENERGYVTYPE& dener,
809809
datom_energy_.resize(static_cast<size_t>(nframes) * nall);
810810
datom_virial_.resize(static_cast<size_t>(nframes) * nall * 9);
811811
for (int ii = 0; ii < nall; ++ii) {
812+
int new_idx = new_idx_map[ii];
812813
for (int dd = 0; dd < 3; ++dd) {
813-
int new_idx = new_idx_map[ii];
814814
dforce_[3 * ii + dd] = dforce_tmp[3 * new_idx + dd];
815815
datom_energy_[ii] = datom_energy_tmp[new_idx];
816-
datom_virial_[ii] = datom_virial_tmp[new_idx];
816+
817817
if (datype_[ii] < ntypes_spin && ii < nloc) {
818818
dforce_mag_[3 * ii + dd] = dforce_tmp[3 * (new_idx + nloc) + dd];
819819
} else if (datype_[ii] < ntypes_spin) {
@@ -822,6 +822,9 @@ void DeepSpinTF::compute(ENERGYVTYPE& dener,
822822
dforce_mag_[3 * ii + dd] = 0.0;
823823
}
824824
}
825+
for (int dd = 0; dd < 9; ++dd) {
826+
datom_virial_[ii * 9 + dd] = datom_virial_tmp[new_idx * 9 + dd];
827+
}
825828
}
826829
}
827830

source/tests/tf/test_model_spin.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def test_model_spin(self) -> None:
126126
)
127127

128128
out_ener = np.reshape(out_ener, [-1])
129+
out_virial = np.reshape(out_virial, [-1])
129130
natoms_real = np.sum(
130131
test_data["natoms_vec"][2 : 2 + len(spin_param["use_spin"])]
131132
)
@@ -425,14 +426,36 @@ def test_model_spin(self) -> None:
425426
-0.0007015535942944564,
426427
0.004459188855221506,
427428
]
429+
refv = [
430+
0.33691325723275595,
431+
0.024301747372056412,
432+
-0.06880806009046331,
433+
0.026792188153995887,
434+
0.3007953744219118,
435+
-0.051612531097108075,
436+
-0.07274496651648972,
437+
-0.05156414798680478,
438+
0.37692901508963417,
439+
0.3205610686355494,
440+
0.013102936385366228,
441+
-0.04419007538301404,
442+
0.014186144311082909,
443+
0.31565216176483,
444+
-0.058829665227551474,
445+
-0.04759429793837308,
446+
-0.05932221615318792,
447+
0.39040431257661773,
448+
]
428449
refe = np.reshape(refe, [-1])
429450
refr = np.reshape(refr, [-1])
430451
refm = np.reshape(refm, [-1])
452+
refv = np.reshape(refv, [-1])
431453

432454
places = 10
433455
np.testing.assert_almost_equal(out_ener, refe, places)
434456
np.testing.assert_almost_equal(force_real, refr, places)
435457
np.testing.assert_almost_equal(force_mag, refm, places)
458+
np.testing.assert_almost_equal(out_virial, refv, places)
436459

437460

438461
if __name__ == "__main__":

0 commit comments

Comments
 (0)