@@ -15,15 +15,20 @@ class Weights(torch.nn.Module):
1515 :class:`SelfAdaptivePINN` solver.
1616 """
1717
18- def __init__ (self , func ):
18+ def __init__ (self , func , num_points ):
1919 """
2020 Initialization of the :class:`Weights` class.
2121
2222 :param torch.nn.Module func: the mask model.
23+ :param int num_points: the number of input points.
2324 """
2425 super ().__init__ ()
26+
27+ # Check consistency
2528 check_consistency (func , torch .nn .Module )
26- self .sa_weights = torch .nn .Parameter (torch .Tensor ())
29+
30+ # Initialize the weights as a learnable parameter
31+ self .sa_weights = torch .nn .Parameter (torch .zeros (num_points , 1 ))
2732 self .func = func
2833
2934 def forward (self ):
@@ -140,134 +145,111 @@ def __init__(
140145 If ``None``, the :class:`torch.nn.MSELoss` loss is used.
141146 Default is `None`.
142147 """
143- # check consistency weitghs_function
148+ # Check consistency
144149 check_consistency (weight_function , torch .nn .Module )
145150
146- # create models for weights
147- weights_dict = {}
148- for condition_name in problem .conditions :
149- weights_dict [ condition_name ] = Weights (weight_function )
150- weights_dict = torch .nn .ModuleDict (weights_dict )
151+ # Define a ModuleDict for the weights
152+ weights = {}
153+ for cond , data in problem .input_pts . items () :
154+ weights [ cond ] = Weights (func = weight_function , num_points = len ( data ) )
155+ weights = torch .nn .ModuleDict (weights )
151156
152157 super ().__init__ (
153- models = [model , weights_dict ],
158+ models = [model , weights ],
154159 problem = problem ,
155160 optimizers = [optimizer_model , optimizer_weights ],
156161 schedulers = [scheduler_model , scheduler_weights ],
157162 weighting = weighting ,
158163 loss = loss ,
159164 )
160165
161- self . _vectorial_loss = deepcopy ( self . loss )
162- self ._vectorial_loss . reduction = "none"
166+ # Extract the reduction method from the loss function
167+ self ._reduction = self . _loss_fn . reduction
163168
164- def forward (self , x ):
165- """
166- Forward pass.
169+ # Set the loss function to return non-aggregated losses
170+ self ._loss_fn = type (self ._loss_fn )(reduction = "none" )
167171
168- :param LabelTensor x: Input tensor.
169- :return: The output of the neural network.
170- :rtype: LabelTensor
172+ def training_step (self , batch , batch_idx , ** kwargs ):
171173 """
172- return self .model (x )
173-
174- def training_step (self , batch ):
175- """
176- Solver training step, overridden to perform manual optimization.
174+ Solver training step. It computes the optimization cycle and aggregates
175+ the losses using the ``weighting`` attribute.
177176
178177 :param list[tuple[str, dict]] batch: A batch of data. Each element is a
179178 tuple containing a condition name and a dictionary of points.
180- :return: The aggregated loss.
181- :rtype: LabelTensor
179+ :param int batch_idx: The index of the current batch.
180+ :param dict kwargs: Additional keyword arguments passed to
181+ ``optimization_cycle``.
182+ :return: The loss of the training step.
183+ :rtype: torch.Tensor
182184 """
183185 # Weights optimization
184186 self .optimizer_weights .instance .zero_grad ()
185- loss = super ().training_step (batch )
187+ loss = self ._optimization_cycle (
188+ batch = batch , batch_idx = batch_idx , ** kwargs
189+ )
186190 self .manual_backward (- loss )
187191 self .optimizer_weights .instance .step ()
188192 self .scheduler_weights .instance .step ()
189193
190194 # Model optimization
191195 self .optimizer_model .instance .zero_grad ()
192- loss = super ().training_step (batch )
196+ loss = self ._optimization_cycle (
197+ batch = batch , batch_idx = batch_idx , ** kwargs
198+ )
193199 self .manual_backward (loss )
194200 self .optimizer_model .instance .step ()
195201 self .scheduler_model .instance .step ()
196202
203+ # Log the loss
204+ self .store_log ("train_loss" , loss , self .get_batch_size (batch ))
205+
197206 return loss
198207
199- def configure_optimizers (self ):
208+ @torch .set_grad_enabled (True )
209+ def validation_step (self , batch , ** kwargs ):
200210 """
201- Optimizer configuration.
211+ The validation step for the Self-Adaptive PINN solver. It returns the
212+ average residual computed with the ``loss`` function not aggregated.
202213
203- :return: The optimizers and the schedulers
204- :rtype: tuple[list[Optimizer], list[Scheduler]]
205- """
206- # If the problem is an InverseProblem, add the unknown parameters
207- # to the parameters to be optimized
208- self .optimizer_model .hook (self .model .parameters ())
209- self .optimizer_weights .hook (self .weights_dict .parameters ())
210- if isinstance (self .problem , InverseProblem ):
211- self .optimizer_model .instance .add_param_group (
212- {
213- "params" : [
214- self ._params [var ]
215- for var in self .problem .unknown_variables
216- ]
217- }
218- )
219- self .scheduler_model .hook (self .optimizer_model )
220- self .scheduler_weights .hook (self .optimizer_weights )
221- return (
222- [self .optimizer_model .instance , self .optimizer_weights .instance ],
223- [self .scheduler_model .instance , self .scheduler_weights .instance ],
224- )
225-
226- def on_train_start (self ):
214+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
215+ tuple containing a condition name and a dictionary of points.
216+ :param dict kwargs: Additional keyword arguments passed to
217+ ``optimization_cycle``.
218+ :return: The loss of the validation step.
219+ :rtype: torch.Tensor
227220 """
228- This method is called at the start of the training process to set the
229- self-adaptive weights as parameters of the mask model.
221+ losses = self .optimization_cycle (batch = batch , ** kwargs )
230222
231- :raises NotImplementedError: If the batch size is not ``None``.
232- """
233- if self .trainer .batch_size is not None :
234- raise NotImplementedError (
235- "SelfAdaptivePINN only works with full "
236- "batch size, set batch_size=None inside "
237- "the Trainer to use the solver."
238- )
239- device = torch .device (
240- self .trainer ._accelerator_connector ._accelerator_flag
241- )
223+ # Aggregate losses for each condition
224+ for cond , loss in losses .items ():
225+ losses [cond ] = self ._apply_reduction (loss = losses [cond ])
242226
243- # Initialize the self adaptive weights only for training points
244- for (
245- condition_name ,
246- tensor ,
247- ) in self .trainer .data_module .train_dataset .input .items ():
248- self .weights_dict [condition_name ].sa_weights .data = torch .rand (
249- (tensor .shape [0 ], 1 ), device = device
250- )
251- return super ().on_train_start ()
227+ loss = (sum (losses .values ()) / len (losses )).as_subclass (torch .Tensor )
228+ self .store_log ("val_loss" , loss , self .get_batch_size (batch ))
229+ return loss
252230
253- def on_load_checkpoint (self , checkpoint ):
231+ @torch .set_grad_enabled (True )
232+ def test_step (self , batch , ** kwargs ):
254233 """
255- Override of the Pytorch Lightning ``on_load_checkpoint`` method to
256- handle checkpoints for Self-Adaptive Weights. This method should not be
257- overridden, if not intentionally.
234+ The test step for the Self-Adaptive PINN solver. It returns the average
235+ residual computed with the ``loss`` function not aggregated.
258236
259- :param dict checkpoint: Pytorch Lightning checkpoint dict.
237+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
238+ tuple containing a condition name and a dictionary of points.
239+ :param dict kwargs: Additional keyword arguments passed to
240+ ``optimization_cycle``.
241+ :return: The loss of the test step.
242+ :rtype: torch.Tensor
260243 """
261- # First initialize self-adaptive weights with correct shape,
262- # then load the values from the checkpoint.
263- for condition_name , _ in self .problem .input_pts .items ():
264- shape = checkpoint ["state_dict" ][
265- f"_pina_models.1.{ condition_name } .sa_weights"
266- ].shape
267- self .weights_dict [condition_name ].sa_weights .data = torch .rand (
268- shape
269- )
270- return super ().on_load_checkpoint (checkpoint )
244+ losses = self .optimization_cycle (batch = batch , ** kwargs )
245+
246+ # Aggregate losses for each condition
247+ for cond , loss in losses .items ():
248+ losses [cond ] = self ._apply_reduction (loss = losses [cond ])
249+
250+ loss = (sum (losses .values ()) / len (losses )).as_subclass (torch .Tensor )
251+ self .store_log ("test_loss" , loss , self .get_batch_size (batch ))
252+ return loss
271253
272254 def loss_phys (self , samples , equation ):
273255 """
@@ -279,47 +261,138 @@ def loss_phys(self, samples, equation):
279261 :return: The computed physics loss.
280262 :rtype: LabelTensor
281263 """
282- residual = self .compute_residual (samples , equation )
283- weights = self .weights_dict [self .current_condition_name ].forward ()
284- loss_value = self ._vectorial_loss (
285- torch .zeros_like (residual , requires_grad = True ), residual
286- )
287- return self ._vect_to_scalar (weights * loss_value )
264+ residuals = self .compute_residual (samples , equation )
265+ return self ._loss_fn (residuals , torch .zeros_like (residuals ))
288266
289267 def loss_data (self , input , target ):
290268 """
291- Compute the data loss for the PINN solver by evaluating the loss
269+ Compute the data loss for the Supervised solver by evaluating the loss
292270 between the network's output and the true solution. This method should
293271 not be overridden, if not intentionally.
294272
295273 :param input: The input to the neural network.
296- :type input: LabelTensor
274+ :type input: LabelTensor | torch.Tensor | Graph | Data
297275 :param target: The target to compare with the network's output.
298- :type target: LabelTensor
276+ :type target: LabelTensor | torch.Tensor | Graph | Data
299277 :return: The supervised loss, averaged over the number of observations.
300- :rtype: LabelTensor
278+ :rtype: LabelTensor | torch.Tensor | Graph | Data
301279 """
302280 return self ._loss_fn (self .forward (input ), target )
303281
304- def _vect_to_scalar (self , loss_value ):
282+ def forward (self , x ):
305283 """
306- Computation of the scalar loss .
284+ Forward pass .
307285
308- :param LabelTensor loss_value: the tensor of pointwise losses.
309- :raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
310- :return: The computed scalar loss.
311- :rtype: LabelTensor
286+ :param x: Input tensor.
287+ :type x: torch.Tensor | LabelTensor
288+ :return: The output of the neural network.
289+ :rtype: torch.Tensor | LabelTensor
290+ """
291+ return self .model (x )
292+
293+ def configure_optimizers (self ):
294+ """
295+ Optimizer configuration.
296+
297+ :return: The optimizers and the schedulers
298+ :rtype: tuple[list[Optimizer], list[Scheduler]]
312299 """
313- if self .loss .reduction == "mean" :
314- ret = torch .mean (loss_value )
315- elif self .loss .reduction == "sum" :
316- ret = torch .sum (loss_value )
317- else :
318- raise RuntimeError (
319- f"Invalid reduction, got { self .loss .reduction } "
320- "but expected mean or sum."
300+ # Hook the optimizers to the models
301+ self .optimizer_model .hook (self .model .parameters ())
302+ self .optimizer_weights .hook (self .weights .parameters ())
303+
304+ # Add unknown parameters to optimization list in case of InverseProblem
305+ if isinstance (self .problem , InverseProblem ):
306+ self .optimizer_model .instance .add_param_group (
307+ {
308+ "params" : [
309+ self ._params [var ]
310+ for var in self .problem .unknown_variables
311+ ]
312+ }
321313 )
322- return ret
314+
315+ # Hook the schedulers to the optimizers
316+ self .scheduler_model .hook (self .optimizer_model )
317+ self .scheduler_weights .hook (self .optimizer_weights )
318+
319+ return (
320+ [self .optimizer_model .instance , self .optimizer_weights .instance ],
321+ [self .scheduler_model .instance , self .scheduler_weights .instance ],
322+ )
323+
324+ def _optimization_cycle (self , batch , batch_idx , ** kwargs ):
325+ """
326+ Aggregate the loss for each condition in the batch.
327+
328+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
329+ tuple containing a condition name and a dictionary of points.
330+ :param int batch_idx: The index of the current batch.
331+ :param dict kwargs: Additional keyword arguments passed to
332+ ``optimization_cycle``.
333+ :return: The losses computed for all conditions in the batch, casted
334+ to a subclass of :class:`torch.Tensor`. It should return a dict
335+ containing the condition name and the associated scalar loss.
336+ :rtype: dict
337+ """
338+ # Compute non-aggregated residuals
339+ residuals = self .optimization_cycle (batch )
340+
341+ # Compute losses
342+ losses = {}
343+ for cond , res in residuals .items ():
344+
345+ weight_tensor = self .weights [cond ]()
346+
347+ # Get the correct indices for the weights. Modulus is used according
348+ # to the number of points in the condition, as in the PinaDataset.
349+ len_res = len (res )
350+ idx = torch .arange (
351+ batch_idx * len_res ,
352+ (batch_idx + 1 ) * len_res ,
353+ device = res .device ,
354+ ) % len (self .problem .input_pts [cond ])
355+
356+ # Apply the weights to the residuals
357+ losses [cond ] = self ._apply_reduction (
358+ loss = (res * weight_tensor [idx ])
359+ )
360+
361+ # Store log
362+ self .store_log (
363+ f"{ cond } _loss" , losses [cond ].item (), self .get_batch_size (batch )
364+ )
365+
366+ # Clamp unknown parameters in InverseProblem (if needed)
367+ self ._clamp_params ()
368+
369+ # Aggregate
370+ loss = self .weighting .aggregate (losses ).as_subclass (torch .Tensor )
371+
372+ return loss
373+
374+ def _apply_reduction (self , loss ):
375+ """
376+ Apply the specified reduction to the loss. The reduction is deferred
377+ until the end of the optimization cycle to allow self-adaptive weights
378+ to be applied to each point beforehand.
379+
380+ :param torch.Tensor loss: The loss tensor to be reduced.
381+ :return: The reduced loss tensor.
382+ :rtype: torch.Tensor
383+ :raises ValueError: If the reduction method is neither "mean" nor "sum".
384+ """
385+ # Apply the specified reduction method
386+ if self ._reduction == "mean" :
387+ return loss .mean ()
388+ if self ._reduction == "sum" :
389+ return loss .sum ()
390+
391+ # Raise an error if the reduction method is not recognized
392+ raise ValueError (
393+ f"Unknown reduction: { self ._reduction } ."
394+ " Supported reductions are 'mean' and 'sum'."
395+ )
323396
324397 @property
325398 def model (self ):
@@ -332,7 +405,7 @@ def model(self):
332405 return self .models [0 ]
333406
334407 @property
335- def weights_dict (self ):
408+ def weights (self ):
336409 """
337410 The self-adaptive weights.
338411
0 commit comments