@@ -94,6 +94,7 @@ def __init__(
9494 of the residuals. Must be between ``0`` and ``1``.
9595 Default is ``0.999``.
9696 :raises: ValueError if `gamma` is not in the range (0, 1).
97+ :raises: ValueError if `eta` is not greater than 0.
9798 """
9899 super ().__init__ (
99100 model = model ,
@@ -114,24 +115,20 @@ def __init__(
114115 f"Invalid range: expected 0 < gamma < 1, but got { gamma } "
115116 )
116117
118+ # Validate range for eta
119+ if eta <= 0 :
120+ raise ValueError (f"Invalid range: expected eta > 0, but got { eta } " )
121+
117122 # Initialize parameters
118123 self .eta = eta
119124 self .gamma = gamma
120125
121126 # Initialize the weight of each point to 0
122- self .weights = {
123- cond : torch .zeros ((len (data ), 1 ), device = self .device )
124- for cond , data in self .problem .input_pts .items ()
125- }
126-
127- def on_train_start (self ):
128- """
129- Hook method called at the beginning of training.
130- """
131- device = self .trainer .strategy .root_device
132- for cond in self .weights :
133- self .weights [cond ] = self .weights [cond ].to (device )
134- return super ().on_train_start ()
127+ self .weights = {}
128+ for cond , data in self .problem .input_pts .items ():
129+ buffer_tensor = torch .zeros ((len (data ), 1 ), device = self .device )
130+ self .register_buffer (f"weight_{ cond } " , buffer_tensor )
131+ self .weights [cond ] = getattr (self , f"weight_{ cond } " )
135132
136133 def training_step (self , batch , batch_idx , ** kwargs ):
137134 """
0 commit comments