@@ -556,7 +556,7 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
556556 coef = learning_rate / self .starter_learning_rate
557557 pref_h = self .limit_pref_h + (self .start_pref_h - self .limit_pref_h ) * coef
558558 # max number of atoms in a batch
559- HESSIAN_BATCH_SIZE = 48 # FIXME: make it configurable
559+ HESSIAN_BATCH_SIZE = 24 # FIXME: make it configurable
560560 if self .has_h :
561561 find_hessian = label ["find_hessian" ]
562562 pref_h : float = pref_h * find_hessian
@@ -569,7 +569,6 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
569569 # add the last slice
570570 if slices [- 1 ] != natoms :
571571 slices .append (natoms )
572-
573572 for i ,j in zip (slices [:- 1 ], slices [1 :]):
574573 h_tile_pred = model ._cal_e_hessian_block (
575574 model_pred ["force" ], input_dict ["coord" ], slice (i , j )
@@ -582,14 +581,18 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
582581 h_tile_diff :torch .Tensor = h_tile_label - h_tile_pred
583582 h_tile_l2 = h_tile_diff .square ().mean ()
584583 h_tile_loss = h_tile_l2 * pref_h
585- if not self .inference and not torch .isnan (h_tile_loss ):
586- # TODO: check if OOM happenes with retain_graph
587- h_tile_loss .backward (retain_graph = True ) # required!
584+ if not self .inference :
585+ h_tile_loss .backward (retain_graph = False )
586+ model_pred , loss , more_loss = super ().forward (
587+ input_dict , model , label , natoms , learning_rate , mae = mae
588+ ) # rebuilding the calculation graph
588589 # Accumulate unbiased metrics (size-weighted across tiles)
589590 total_abs_err = total_abs_err + h_tile_diff .abs ().sum ().detach ()
590591 total_sse = total_sse + h_tile_diff .square ().sum ().detach ()
591592 total_count += int (h_tile_diff .numel ())
592593
594+ # Note: Observed metrics in training are better than expected,
595+ # for the hessians calculated later in the loop are inferred with the updated weights.
593596 rmse_h = torch .sqrt (total_sse / total_count )
594597 more_loss ["rmse_h" ] = self .display_if_exist (rmse_h , find_hessian )
595598 if mae :
0 commit comments