Skip to content

Commit 7234d2a

Browse files
committed
rebuild compute graph to save memory
1 parent 146e480 commit 7234d2a

1 file changed

Lines changed: 8 additions & 5 deletions

File tree

deepmd/pt/loss/ener.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
556556
coef = learning_rate / self.starter_learning_rate
557557
pref_h = self.limit_pref_h + (self.start_pref_h - self.limit_pref_h) * coef
558558
# max number of atoms in a batch
559-
HESSIAN_BATCH_SIZE = 48 # FIXME: make it configurable
559+
HESSIAN_BATCH_SIZE = 24 # FIXME: make it configurable
560560
if self.has_h:
561561
find_hessian = label["find_hessian"]
562562
pref_h: float = pref_h * find_hessian
@@ -569,7 +569,6 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
569569
# add the last slice
570570
if slices[-1] != natoms:
571571
slices.append(natoms)
572-
573572
for i,j in zip(slices[:-1], slices[1:]):
574573
h_tile_pred = model._cal_e_hessian_block(
575574
model_pred["force"], input_dict["coord"], slice(i, j)
@@ -582,14 +581,18 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
582581
h_tile_diff:torch.Tensor = h_tile_label - h_tile_pred
583582
h_tile_l2 = h_tile_diff.square().mean()
584583
h_tile_loss = h_tile_l2 * pref_h
585-
if not self.inference and not torch.isnan(h_tile_loss):
586-
# TODO: check if OOM happens with retain_graph
587-
h_tile_loss.backward(retain_graph=True) # required!
584+
if not self.inference:
585+
h_tile_loss.backward(retain_graph=False)
586+
model_pred, loss, more_loss = super().forward(
587+
input_dict, model, label, natoms, learning_rate, mae=mae
588+
) # rebuilding the calculation graph
588589
# Accumulate unbiased metrics (size-weighted across tiles)
589590
total_abs_err = total_abs_err + h_tile_diff.abs().sum().detach()
590591
total_sse = total_sse + h_tile_diff.square().sum().detach()
591592
total_count += int(h_tile_diff.numel())
592593

594+
# Note: the metrics observed during training look better than expected,
595+
# because the hessians calculated later in the loop are inferred with the updated weights.
593596
rmse_h = torch.sqrt(total_sse / total_count)
594597
more_loss["rmse_h"] = self.display_if_exist(rmse_h, find_hessian)
595598
if mae:

0 commit comments

Comments
 (0)