Skip to content

Commit 3a5c51c

Browse files
committed
feat: calculate hessian by part in the training loop
1 parent c7edd6f commit 3a5c51c

3 files changed

Lines changed: 47 additions & 20 deletions

File tree

deepmd/pt/loss/ener.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -555,29 +555,55 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
555555
)
556556
coef = learning_rate / self.starter_learning_rate
557557
pref_h = self.limit_pref_h + (self.start_pref_h - self.limit_pref_h) * coef
558-
559-
if self.has_h and "hessian" in model_pred and "hessian" in label:
560-
find_hessian = label.get("find_hessian", 0.0)
561-
pref_h = pref_h * find_hessian
562-
diff_h = label["hessian"].reshape(
563-
-1,
564-
) - model_pred["hessian"].reshape(
565-
-1,
566-
)
567-
l2_hessian_loss = torch.mean(torch.square(diff_h))
558+
# max number of atoms in a batch
559+
HESSIAN_BATCH_SIZE = 48 # FIXME: make it configurable
560+
if self.has_h:
561+
find_hessian = label["find_hessian"]
562+
pref_h: float = pref_h * find_hessian
563+
# Accumulate global sums for unbiased MAE/RMSE across tiles
564+
total_abs_err = torch.zeros((), device=env.DEVICE)
565+
total_sse = torch.zeros((), device=env.DEVICE)
566+
total_count: int = 0
567+
# split hessian calculations into batches
568+
slices = list(range(0, natoms, HESSIAN_BATCH_SIZE))
569+
# add the last slice
570+
if slices[-1] != natoms:
571+
slices.append(natoms)
572+
573+
for i,j in zip(slices[:-1], slices[1:]):
574+
h_tile_pred = model._cal_e_hessian_block(
575+
model_pred["force"], input_dict["coord"], slice(i, j)
576+
) # assuming force is always calculated
577+
h_tile_label = label["hessian"].view(
578+
input_dict["coord"].shape[0], # nframes
579+
natoms * 3,
580+
natoms * 3,
581+
)[:, None, i * 3 : j * 3, :]
582+
h_tile_diff:torch.Tensor = h_tile_label - h_tile_pred
583+
h_tile_l2 = h_tile_diff.square().mean()
584+
h_tile_loss = h_tile_l2 * pref_h
585+
if not self.inference and not torch.isnan(h_tile_loss):
586+
# TODO: check if OOM happenes with retain_graph
587+
h_tile_loss.backward(retain_graph=True) # required!
588+
# Accumulate unbiased metrics (size-weighted across tiles)
589+
total_abs_err = total_abs_err + h_tile_diff.abs().sum().detach()
590+
total_sse = total_sse + h_tile_diff.square().sum().detach()
591+
total_count += int(h_tile_diff.numel())
592+
593+
rmse_h = torch.sqrt(total_sse / total_count)
594+
more_loss["rmse_h"] = self.display_if_exist(rmse_h, find_hessian)
595+
if mae:
596+
mae_h = total_abs_err / total_count
597+
more_loss["mae_h"] = self.display_if_exist(mae_h, find_hessian)
568598
if not self.inference:
569599
more_loss["l2_hessian_loss"] = self.display_if_exist(
570-
l2_hessian_loss.detach(), find_hessian
600+
total_sse / total_count, find_hessian
571601
)
572-
loss += pref_h * l2_hessian_loss
573-
rmse_h = l2_hessian_loss.sqrt()
574-
more_loss["rmse_h"] = self.display_if_exist(rmse_h.detach(), find_hessian)
575-
if mae:
576-
mae_h = torch.mean(torch.abs(diff_h))
577-
more_loss["mae_h"] = self.display_if_exist(mae_h.detach(), find_hessian)
578602

579603
if not self.inference:
580-
more_loss["rmse"] = torch.sqrt(loss.detach())
604+
more_loss["rmse"] = torch.sqrt(
605+
loss.detach() + pref_h * more_loss.get("l2_hessian_loss",0)
606+
)
581607
return model_pred, loss, more_loss
582608

583609
@property

deepmd/pt/model/model/ener_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def forward(
123123
model_predict["force"] = model_ret["dforce"]
124124
if "mask" in model_ret:
125125
model_predict["mask"] = model_ret["mask"]
126-
if self._hessian_enabled:
126+
if self._hessian_enabled and "energy_derv_r_derv_r" in model_ret:
127127
model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-2)
128128
else:
129129
model_predict = model_ret

deepmd/pt/model/model/make_hessian_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def forward_common(
106106
)
107107

108108
if any(hess_yes):
109+
return ret # HACK: Calculate hessian in the training loop
109110
if (
110111
vdef["energy"].r_hessian
111112
and sum(hess_yes) == 1
@@ -145,7 +146,7 @@ def _cal_e_hessian_block(
145146
.view(nslice * 3, nslice, 3)
146147
.unsqueeze(1) # (nslice * 3, 1, nslice, 3)
147148
.expand(-1, nf, -1, -1), # (nslice * 3, nf, nslice, 3)
148-
create_graph=self.training,
149+
create_graph=True, #self.training,
149150
retain_graph=True,
150151
is_grads_batched=True,
151152
)[0] # (nslice * 3, nf, nloc, 3)

0 commit comments

Comments
 (0)