11# SPDX-License-Identifier: LGPL-3.0-or-later
2+ from typing import (
3+ Any ,
4+ )
25
36import paddle
47import paddle .nn .functional as F
2023)
2124
2225
23- def custom_huber_loss (predictions , targets , delta = 1.0 ):
26+ def custom_huber_loss (
27+ predictions : paddle .Tensor , targets : paddle .Tensor , delta : float = 1.0
28+ ) -> paddle .Tensor :
2429 error = targets - predictions
2530 abs_error = paddle .abs (error )
2631 quadratic_loss = 0.5 * paddle .pow (error , 2 )
@@ -32,13 +37,13 @@ def custom_huber_loss(predictions, targets, delta=1.0):
3237class EnergyStdLoss (TaskLoss ):
3338 def __init__ (
3439 self ,
35- starter_learning_rate = 1.0 ,
36- start_pref_e = 0.0 ,
37- limit_pref_e = 0.0 ,
38- start_pref_f = 0.0 ,
39- limit_pref_f = 0.0 ,
40- start_pref_v = 0.0 ,
41- limit_pref_v = 0.0 ,
40+ starter_learning_rate : float = 1.0 ,
41+ start_pref_e : float = 0.0 ,
42+ limit_pref_e : float = 0.0 ,
43+ start_pref_f : float = 0.0 ,
44+ limit_pref_f : float = 0.0 ,
45+ start_pref_v : float = 0.0 ,
46+ limit_pref_v : float = 0.0 ,
4247 start_pref_ae : float = 0.0 ,
4348 limit_pref_ae : float = 0.0 ,
4449 start_pref_pf : float = 0.0 ,
@@ -49,10 +54,10 @@ def __init__(
4954 limit_pref_gf : float = 0.0 ,
5055 numb_generalized_coord : int = 0 ,
5156 use_l1_all : bool = False ,
52- inference = False ,
53- use_huber = False ,
54- huber_delta = 0.01 ,
55- ** kwargs ,
57+ inference : bool = False ,
58+ use_huber : bool = False ,
59+ huber_delta : float = 0.01 ,
60+ ** kwargs : Any ,
5661 ) -> None :
5762 r"""Construct a layer to compute loss on energy, force and virial.
5863
@@ -146,7 +151,15 @@ def __init__(
146151 "Huber loss is not implemented for force with atom_pref, generalized force and relative force. "
147152 )
148153
149- def forward (self , input_dict , model , label , natoms , learning_rate , mae = False ):
154+ def forward (
155+ self ,
156+ input_dict : dict [str , paddle .Tensor ],
157+ model : paddle .nn .Layer ,
158+ label : dict [str , paddle .Tensor ],
159+ natoms : int ,
160+ learning_rate : float ,
161+ mae : bool = False ,
162+ ) -> tuple [dict [str , paddle .Tensor ], paddle .Tensor , dict [str , paddle .Tensor ]]:
150163 """Return loss on energy and force.
151164
152165 Parameters
@@ -535,10 +548,10 @@ def deserialize(cls, data: dict) -> "TaskLoss":
535548class EnergyHessianStdLoss (EnergyStdLoss ):
536549 def __init__ (
537550 self ,
538- start_pref_h = 0.0 ,
539- limit_pref_h = 0.0 ,
540- ** kwargs ,
541- ):
551+ start_pref_h : float = 0.0 ,
552+ limit_pref_h : float = 0.0 ,
553+ ** kwargs : Any ,
554+ ) -> None :
542555 r"""Enable the layer to compute loss on hessian.
543556
544557 Parameters
@@ -556,7 +569,15 @@ def __init__(
556569 self .start_pref_h = start_pref_h
557570 self .limit_pref_h = limit_pref_h
558571
559- def forward (self , input_dict , model , label , natoms , learning_rate , mae = False ):
572+ def forward (
573+ self ,
574+ input_dict : dict [str , paddle .Tensor ],
575+ model : paddle .nn .Module ,
576+ label : dict [str , paddle .Tensor ],
577+ natoms : int ,
578+ learning_rate : float ,
579+ mae : bool = False ,
580+ ) -> tuple [dict [str , paddle .Tensor ], paddle .Tensor , dict [str , paddle .Tensor ]]:
560581 model_pred , loss , more_loss = super ().forward (
561582 input_dict , model , label , natoms , learning_rate , mae = mae
562583 )
0 commit comments