Skip to content

Commit 8fd9565

Browse files
committed
feat(pt): support spin virial
1 parent 9af197c commit 8fd9565

11 files changed

Lines changed: 126 additions & 40 deletions

File tree

deepmd/pt/loss/ener_spin.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,22 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
268268
rmse_ae.detach(), find_atom_ener
269269
)
270270

271+
if self.has_v and "virial" in model_pred and "virial" in label:
272+
find_virial = label.get("find_virial", 0.0)
273+
pref_v = pref_v * find_virial
274+
diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
275+
l2_virial_loss = torch.mean(torch.square(diff_v))
276+
if not self.inference:
277+
more_loss["l2_virial_loss"] = self.display_if_exist(
278+
l2_virial_loss.detach(), find_virial
279+
)
280+
loss += atom_norm * (pref_v * l2_virial_loss)
281+
rmse_v = l2_virial_loss.sqrt() * atom_norm
282+
more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
283+
if mae:
284+
mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
285+
more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
286+
271287
if not self.inference:
272288
more_loss["rmse"] = torch.sqrt(loss.detach())
273289
return model_pred, loss, more_loss

deepmd/pt/model/model/make_model.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def forward_common(
135135
fparam: Optional[torch.Tensor] = None,
136136
aparam: Optional[torch.Tensor] = None,
137137
do_atomic_virial: bool = False,
138+
coord_corr_for_virial: Optional[torch.Tensor] = None,
138139
) -> dict[str, torch.Tensor]:
139140
"""Return model prediction.
140141
@@ -153,6 +154,9 @@ def forward_common(
153154
atomic parameter. nf x nloc x nda
154155
do_atomic_virial
155156
If calculate the atomic virial.
157+
coord_corr_for_virial
158+
The coordinates correction of the atoms for virial.
159+
shape: nf x (nloc x 3)
156160
157161
Returns
158162
-------
@@ -180,6 +184,14 @@ def forward_common(
180184
mixed_types=True,
181185
box=bb,
182186
)
187+
if coord_corr_for_virial is not None:
188+
coord_corr_for_virial = coord_corr_for_virial.to(cc.dtype)
189+
extended_coord_corr = torch.gather(
190+
coord_corr_for_virial, 1, mapping.unsqueeze(-1).expand(-1, -1, 3)
191+
)
192+
else:
193+
extended_coord_corr = None
194+
183195
model_predict_lower = self.forward_common_lower(
184196
extended_coord,
185197
extended_atype,
@@ -188,6 +200,7 @@ def forward_common(
188200
do_atomic_virial=do_atomic_virial,
189201
fparam=fp,
190202
aparam=ap,
203+
extended_coord_corr=extended_coord_corr,
191204
)
192205
model_predict = communicate_extended_output(
193206
model_predict_lower,
@@ -242,6 +255,7 @@ def forward_common_lower(
242255
do_atomic_virial: bool = False,
243256
comm_dict: Optional[dict[str, torch.Tensor]] = None,
244257
extra_nlist_sort: bool = False,
258+
extended_coord_corr: Optional[torch.Tensor] = None,
245259
):
246260
"""Return model prediction. Lower interface that takes
247261
extended atomic coordinates and types, nlist, and mapping
@@ -268,6 +282,8 @@ def forward_common_lower(
268282
The data needed for communication for parallel inference.
269283
extra_nlist_sort
270284
whether to forcibly sort the nlist.
285+
extended_coord_corr
286+
coordinates correction for virial in extended region. nf x (nall x 3)
271287
272288
Returns
273289
-------
@@ -299,6 +315,7 @@ def forward_common_lower(
299315
cc_ext,
300316
do_atomic_virial=do_atomic_virial,
301317
create_graph=self.training,
318+
extended_coord_corr=extended_coord_corr,
302319
)
303320
model_predict = self.output_type_cast(model_predict, input_prec)
304321
return model_predict

deepmd/pt/model/model/spin_model.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,14 @@ def process_spin_input(self, coord, atype, spin):
5454
coord = coord.reshape(nframes, nloc, 3)
5555
spin = spin.reshape(nframes, nloc, 3)
5656
atype_spin = torch.concat([atype, atype + self.ntypes_real], dim=-1)
57-
virtual_coord = coord + spin * (self.virtual_scale_mask.to(atype.device))[
58-
atype
59-
].reshape([nframes, nloc, 1])
57+
spin_dist = spin * (self.virtual_scale_mask.to(atype.device))[atype].reshape(
58+
[nframes, nloc, 1]
59+
)
60+
virtual_coord = coord + spin_dist
6061
coord_spin = torch.concat([coord, virtual_coord], dim=-2)
61-
return coord_spin, atype_spin
62+
# for spin virial corr
63+
coord_corr = torch.concat([torch.zeros_like(coord), -spin_dist], dim=-2)
64+
return coord_spin, atype_spin, coord_corr
6265

6366
def process_spin_input_lower(
6467
self,
@@ -78,13 +81,18 @@ def process_spin_input_lower(
7881
"""
7982
nframes, nall = extended_coord.shape[:2]
8083
nloc = nlist.shape[1]
81-
virtual_extended_coord = extended_coord + extended_spin * (
84+
extended_spin_dist = extended_spin * (
8285
self.virtual_scale_mask.to(extended_atype.device)
8386
)[extended_atype].reshape([nframes, nall, 1])
87+
virtual_extended_coord = extended_coord + extended_spin_dist
8488
virtual_extended_atype = extended_atype + self.ntypes_real
8589
extended_coord_updated = concat_switch_virtual(
8690
extended_coord, virtual_extended_coord, nloc
8791
)
92+
# for spin virial corr
93+
extended_coord_corr = concat_switch_virtual(
94+
torch.zeros_like(extended_coord), -extended_spin_dist, nloc
95+
)
8896
extended_atype_updated = concat_switch_virtual(
8997
extended_atype, virtual_extended_atype, nloc
9098
)
@@ -100,6 +108,7 @@ def process_spin_input_lower(
100108
extended_atype_updated,
101109
nlist_updated,
102110
mapping_updated,
111+
extended_coord_corr,
103112
)
104113

105114
def process_spin_output(
@@ -367,7 +376,7 @@ def spin_sampled_func():
367376
sampled = sampled_func()
368377
spin_sampled = []
369378
for sys in sampled:
370-
coord_updated, atype_updated = self.process_spin_input(
379+
coord_updated, atype_updated, _ = self.process_spin_input(
371380
sys["coord"], sys["atype"], sys["spin"]
372381
)
373382
tmp_dict = {
@@ -398,7 +407,9 @@ def forward_common(
398407
do_atomic_virial: bool = False,
399408
) -> dict[str, torch.Tensor]:
400409
nframes, nloc = atype.shape
401-
coord_updated, atype_updated = self.process_spin_input(coord, atype, spin)
410+
coord_updated, atype_updated, coord_corr_for_virial = self.process_spin_input(
411+
coord, atype, spin
412+
)
402413
if aparam is not None:
403414
aparam = self.expand_aparam(aparam, nloc * 2)
404415
model_ret = self.backbone_model.forward_common(
@@ -408,6 +419,7 @@ def forward_common(
408419
fparam=fparam,
409420
aparam=aparam,
410421
do_atomic_virial=do_atomic_virial,
422+
coord_corr_for_virial=coord_corr_for_virial,
411423
)
412424
model_output_type = self.backbone_model.model_output_type()
413425
if "mask" in model_output_type:
@@ -454,6 +466,7 @@ def forward_common_lower(
454466
extended_atype_updated,
455467
nlist_updated,
456468
mapping_updated,
469+
extended_coord_corr_for_virial,
457470
) = self.process_spin_input_lower(
458471
extended_coord, extended_atype, extended_spin, nlist, mapping=mapping
459472
)
@@ -469,6 +482,7 @@ def forward_common_lower(
469482
do_atomic_virial=do_atomic_virial,
470483
comm_dict=comm_dict,
471484
extra_nlist_sort=extra_nlist_sort,
485+
extended_coord_corr=extended_coord_corr_for_virial,
472486
)
473487
model_output_type = self.backbone_model.model_output_type()
474488
if "mask" in model_output_type:
@@ -541,6 +555,11 @@ def translated_output_def(self):
541555
output_def["force"].squeeze(-2)
542556
output_def["force_mag"] = deepcopy(out_def_data["energy_derv_r_mag"])
543557
output_def["force_mag"].squeeze(-2)
558+
if self.do_grad_c("energy"):
559+
output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
560+
output_def["virial"].squeeze(-2)
561+
output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
562+
output_def["atom_virial"].squeeze(-3)
544563
return output_def
545564

546565
def forward(
@@ -569,7 +588,10 @@ def forward(
569588
if self.backbone_model.do_grad_r("energy"):
570589
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
571590
model_predict["force_mag"] = model_ret["energy_derv_r_mag"].squeeze(-2)
572-
# not support virial by far
591+
if self.backbone_model.do_grad_c("energy"):
592+
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
593+
if do_atomic_virial:
594+
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
573595
return model_predict
574596

575597
@torch.jit.export
@@ -606,5 +628,10 @@ def forward_lower(
606628
model_predict["extended_force_mag"] = model_ret[
607629
"energy_derv_r_mag"
608630
].squeeze(-2)
609-
# not support virial by far
631+
if self.backbone_model.do_grad_c("energy"):
632+
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
633+
if do_atomic_virial:
634+
model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(
635+
-3
636+
)
610637
return model_predict

deepmd/pt/model/model/transform_output.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def fit_output_to_model_output(
156156
coord_ext: torch.Tensor,
157157
do_atomic_virial: bool = False,
158158
create_graph: bool = True,
159+
extended_coord_corr: Optional[torch.Tensor] = None,
159160
) -> dict[str, torch.Tensor]:
160161
"""Transform the output of the fitting network to
161162
the model output.
@@ -187,6 +188,12 @@ def fit_output_to_model_output(
187188
model_ret[kk_derv_r] = dr
188189
if vdef.c_differentiable:
189190
assert dc is not None
191+
if extended_coord_corr is not None:
192+
dc_corr = (
193+
dr.squeeze(-2).unsqueeze(-1)
194+
@ extended_coord_corr.unsqueeze(-2)
195+
).view(list(dc.shape[:-2]) + [1, 9]) # noqa: RUF005
196+
dc = dc + dc_corr
190197
model_ret[kk_derv_c] = dc
191198
model_ret[kk_derv_c + "_redu"] = torch.sum(
192199
model_ret[kk_derv_c].to(redu_prec), dim=1

source/api_c/include/deepmd.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2602,9 +2602,9 @@ class DeepSpinModelDevi : public DeepBaseModelDevi {
26022602
for (int j = 0; j < natoms * 3; j++) {
26032603
force_mag[i][j] = force_mag_flat[i * natoms * 3 + j];
26042604
}
2605-
// for (int j = 0; j < 9; j++) {
2606-
// virial[i][j] = virial_flat[i * 9 + j];
2607-
// }
2605+
for (int j = 0; j < 9; j++) {
2606+
virial[i][j] = virial_flat[i * 9 + j];
2607+
}
26082608
}
26092609
};
26102610
/**
@@ -2705,9 +2705,9 @@ class DeepSpinModelDevi : public DeepBaseModelDevi {
27052705
for (int j = 0; j < natoms * 3; j++) {
27062706
force_mag[i][j] = force_mag_flat[i * natoms * 3 + j];
27072707
}
2708-
// for (int j = 0; j < 9; j++) {
2709-
// virial[i][j] = virial_flat[i * 9 + j];
2710-
// }
2708+
for (int j = 0; j < 9; j++) {
2709+
virial[i][j] = virial_flat[i * 9 + j];
2710+
}
27112711
for (int j = 0; j < natoms; j++) {
27122712
atom_energy[i][j] = atom_energy_flat[i * natoms + j];
27132713
}

source/api_c/src/c_api.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -862,11 +862,11 @@ void DP_DeepSpinModelDeviCompute_variant(DP_DeepSpinModelDevi* dp,
862862
flatten_vector(fm_flat, fm);
863863
std::copy(fm_flat.begin(), fm_flat.end(), force_mag);
864864
}
865-
// if (virial) {
866-
// std::vector<VALUETYPE> v_flat;
867-
// flatten_vector(v_flat, v);
868-
// std::copy(v_flat.begin(), v_flat.end(), virial);
869-
// }
865+
if (virial) {
866+
std::vector<VALUETYPE> v_flat;
867+
flatten_vector(v_flat, v);
868+
std::copy(v_flat.begin(), v_flat.end(), virial);
869+
}
870870
if (atomic_energy) {
871871
std::vector<VALUETYPE> ae_flat;
872872
flatten_vector(ae_flat, ae);

source/api_cc/src/DeepSpinPT.cc

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
251251
c10::IValue energy_ = outputs.at("energy");
252252
c10::IValue force_ = outputs.at("extended_force");
253253
c10::IValue force_mag_ = outputs.at("extended_force_mag");
254-
// spin model not suported yet
255-
// c10::IValue virial_ = outputs.at("virial");
254+
c10::IValue virial_ = outputs.at("virial");
256255
torch::Tensor flat_energy_ = energy_.toTensor().view({-1});
257256
torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU);
258257
ener.assign(cpu_energy_.data_ptr<ENERGYTYPE>(),
@@ -267,11 +266,11 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
267266
dforce_mag.assign(
268267
cpu_force_mag_.data_ptr<VALUETYPE>(),
269268
cpu_force_mag_.data_ptr<VALUETYPE>() + cpu_force_mag_.numel());
270-
// spin model not suported yet
271-
// torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
272-
// torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
273-
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
274-
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
269+
270+
torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
271+
torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
272+
virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
273+
cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
275274

276275
// bkw map
277276
force.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
@@ -415,8 +414,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
415414
c10::IValue energy_ = outputs.at("energy");
416415
c10::IValue force_ = outputs.at("force");
417416
c10::IValue force_mag_ = outputs.at("force_mag");
418-
// spin model not suported yet
419-
// c10::IValue virial_ = outputs.at("virial");
417+
c10::IValue virial_ = outputs.at("virial");
420418
torch::Tensor flat_energy_ = energy_.toTensor().view({-1});
421419
torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU);
422420
ener.assign(cpu_energy_.data_ptr<ENERGYTYPE>(),
@@ -431,11 +429,10 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
431429
force_mag.assign(
432430
cpu_force_mag_.data_ptr<VALUETYPE>(),
433431
cpu_force_mag_.data_ptr<VALUETYPE>() + cpu_force_mag_.numel());
434-
// spin model not suported yet
435-
// torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
436-
// torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
437-
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
438-
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
432+
torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
433+
torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
434+
virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
435+
cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
439436
if (atomic) {
440437
// c10::IValue atom_virial_ = outputs.at("atom_virial");
441438
c10::IValue atom_energy_ = outputs.at("atom_energy");

source/tests/pt/model/test_autodiff.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,17 @@ def test(
141141
cell = (cell) + 5.0 * torch.eye(3, device="cpu")
142142
coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator)
143143
coord = torch.matmul(coord, cell)
144+
spin = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator)
144145
atype = torch.IntTensor([0, 0, 0, 1, 1])
145146
# assumes input to be numpy tensor
146147
coord = coord.numpy()
148+
spin = spin.numpy()
147149
cell = cell.numpy()
148-
test_keys = ["energy", "force", "virial"]
150+
test_spin = getattr(self, "test_spin", False)
151+
if not test_spin:
152+
test_keys = ["energy", "force", "virial"]
153+
else:
154+
test_keys = ["energy", "force", "force_mag", "virial"]
149155

150156
def np_infer(
151157
new_cell,
@@ -157,6 +163,7 @@ def np_infer(
157163
).unsqueeze(0),
158164
torch.tensor(new_cell, device="cpu").unsqueeze(0),
159165
atype,
166+
spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0),
160167
)
161168
# detach
162169
ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys}
@@ -251,3 +258,11 @@ def setUp(self) -> None:
251258
self.type_split = False
252259
self.test_spin = True
253260
self.model = get_model(model_params).to(env.DEVICE)
261+
262+
263+
class TestEnergyModelSpinSeAVirial(unittest.TestCase, VirialTest):
264+
def setUp(self) -> None:
265+
model_params = copy.deepcopy(model_spin)
266+
self.type_split = False
267+
self.test_spin = True
268+
self.model = get_model(model_params).to(env.DEVICE)

source/tests/pt/model/test_ener_spin_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_input_output_process(self) -> None:
115115
nframes, nloc = self.coord.shape[:2]
116116
self.real_ntypes = self.model.spin.get_ntypes_real()
117117
# 1. test forward input process
118-
coord_updated, atype_updated = self.model.process_spin_input(
118+
coord_updated, atype_updated, _ = self.model.process_spin_input(
119119
self.coord, self.atype, self.spin
120120
)
121121
# compare atypes of real and virtual atoms
@@ -174,6 +174,7 @@ def test_input_output_process(self) -> None:
174174
extended_atype_updated,
175175
nlist_updated,
176176
mapping_updated,
177+
_,
177178
) = self.model.process_spin_input_lower(
178179
extended_coord, extended_atype, extended_spin, nlist, mapping=mapping
179180
)

0 commit comments

Comments
 (0)