@@ -555,29 +555,55 @@ def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
555555 )
556556 coef = learning_rate / self .starter_learning_rate
557557 pref_h = self .limit_pref_h + (self .start_pref_h - self .limit_pref_h ) * coef
558-
559- if self .has_h and "hessian" in model_pred and "hessian" in label :
560- find_hessian = label .get ("find_hessian" , 0.0 )
561- pref_h = pref_h * find_hessian
562- diff_h = label ["hessian" ].reshape (
563- - 1 ,
564- ) - model_pred ["hessian" ].reshape (
565- - 1 ,
566- )
567- l2_hessian_loss = torch .mean (torch .square (diff_h ))
558+ # max number of atoms in a batch
559+ HESSIAN_BATCH_SIZE = 48 # FIXME: make it configurable
560+ if self .has_h :
561+ find_hessian = label ["find_hessian" ]
562+ pref_h : float = pref_h * find_hessian
563+ # Accumulate global sums for unbiased MAE/RMSE across tiles
564+ total_abs_err = torch .zeros ((), device = env .DEVICE )
565+ total_sse = torch .zeros ((), device = env .DEVICE )
566+ total_count : int = 0
567+ # split hessian calculations into batches
568+ slices = list (range (0 , natoms , HESSIAN_BATCH_SIZE ))
569+ # add the last slice
570+ if slices [- 1 ] != natoms :
571+ slices .append (natoms )
572+
573+ for i ,j in zip (slices [:- 1 ], slices [1 :]):
574+ h_tile_pred = model ._cal_e_hessian_block (
575+ model_pred ["force" ], input_dict ["coord" ], slice (i , j )
576+ ) # assuming force is always calculated
577+ h_tile_label = label ["hessian" ].view (
578+ input_dict ["coord" ].shape [0 ], # nframes
579+ natoms * 3 ,
580+ natoms * 3 ,
581+ )[:, None , i * 3 : j * 3 , :]
582+ h_tile_diff :torch .Tensor = h_tile_label - h_tile_pred
583+ h_tile_l2 = h_tile_diff .square ().mean ()
584+ h_tile_loss = h_tile_l2 * pref_h
585+ if not self .inference and not torch .isnan (h_tile_loss ):
586+ # TODO: check if OOM happenes with retain_graph
587+ h_tile_loss .backward (retain_graph = True ) # required!
588+ # Accumulate unbiased metrics (size-weighted across tiles)
589+ total_abs_err = total_abs_err + h_tile_diff .abs ().sum ().detach ()
590+ total_sse = total_sse + h_tile_diff .square ().sum ().detach ()
591+ total_count += int (h_tile_diff .numel ())
592+
593+ rmse_h = torch .sqrt (total_sse / total_count )
594+ more_loss ["rmse_h" ] = self .display_if_exist (rmse_h , find_hessian )
595+ if mae :
596+ mae_h = total_abs_err / total_count
597+ more_loss ["mae_h" ] = self .display_if_exist (mae_h , find_hessian )
568598 if not self .inference :
569599 more_loss ["l2_hessian_loss" ] = self .display_if_exist (
570- l2_hessian_loss . detach () , find_hessian
600+ total_sse / total_count , find_hessian
571601 )
572- loss += pref_h * l2_hessian_loss
573- rmse_h = l2_hessian_loss .sqrt ()
574- more_loss ["rmse_h" ] = self .display_if_exist (rmse_h .detach (), find_hessian )
575- if mae :
576- mae_h = torch .mean (torch .abs (diff_h ))
577- more_loss ["mae_h" ] = self .display_if_exist (mae_h .detach (), find_hessian )
578602
579603 if not self .inference :
580- more_loss ["rmse" ] = torch .sqrt (loss .detach ())
604+ more_loss ["rmse" ] = torch .sqrt (
605+ loss .detach () + pref_h * more_loss .get ("l2_hessian_loss" ,0 )
606+ )
581607 return model_pred , loss , more_loss
582608
583609 @property
0 commit comments