11"""Module for the Residual-Based Attention PINN solver."""
22
3- from copy import deepcopy
43import torch
54
65from .pinn import PINN
@@ -73,7 +72,6 @@ def __init__(
7372 optimizer = None ,
7473 scheduler = None ,
7574 weighting = None ,
76- loss = None ,
7775 eta = 0.001 ,
7876 gamma = 0.999 ,
7977 ):
@@ -90,99 +88,193 @@ def __init__(
9088 scheduler is used. Default is ``None``.
9189 :param WeightingInterface weighting: The weighting schema to be used.
9290 If ``None``, no weighting schema is used. Default is ``None``.
93- :param torch.nn.Module loss: The loss function to be minimized.
94- If ``None``, the :class:`torch.nn.MSELoss` loss is used.
95- Default is `None`.
9691 :param float | int eta: The learning rate for the weights of the
9792 residuals. Default is ``0.001``.
9893 :param float gamma: The decay parameter in the update of the weights
9994 of the residuals. Must be between ``0`` and ``1``.
10095 Default is ``0.999``.
96+ :raises ValueError: If ``gamma`` is not in the range (0, 1).
10197 """
10298 super ().__init__ (
10399 model = model ,
104100 problem = problem ,
105101 optimizer = optimizer ,
106102 scheduler = scheduler ,
107103 weighting = weighting ,
108- loss = loss ,
104+ loss = torch . nn . MSELoss ( reduction = "none" ) ,
109105 )
110106
111107 # check consistency
112108 check_consistency (eta , (float , int ))
113109 check_consistency (gamma , float )
114- assert (
115- 0 < gamma < 1
116- ), f"Invalid range: expected 0 < gamma < 1, got { gamma = } "
110+
111+ # Validate range for gamma
112+ if not 0 < gamma < 1 :
113+ raise ValueError (
114+ f"Invalid range: expected 0 < gamma < 1, but got { gamma } "
115+ )
116+
117+ # Initialize parameters
117118 self .eta = eta
118119 self .gamma = gamma
119120
120- # initialize weights
121- self .weights = {}
122- for condition_name in problem .conditions :
123- self .weights [condition_name ] = 0
121+ # Initialize the weight of each point to 0
122+ self .weights = {
123+ cond : torch .zeros ((len (data ), 1 ), device = self .device )
124+ for cond , data in self .problem .input_pts .items ()
125+ }
124126
125- # define vectorial loss
126- self ._vectorial_loss = deepcopy (self .loss )
127- self ._vectorial_loss .reduction = "none"
128-
129- # for now RBAPINN is implemented only for batch_size = None
def on_train_start(self):
    """
    Hook method called at the beginning of training.

    Moves every per-condition weight tensor stored in ``self.weights`` to
    the root device of the trainer strategy, then delegates to the parent
    hook.

    :return: The value returned by the parent ``on_train_start`` hook.
    """
    root_device = self.trainer.strategy.root_device

    # Only existing keys are re-assigned, so the dictionary size does not
    # change while it is being iterated.
    for cond, weight in self.weights.items():
        self.weights[cond] = weight.to(root_device)

    return super().on_train_start()
143135
144- def _vect_to_scalar (self , loss_value ):
def training_step(self, batch, batch_idx, **kwargs):
    """
    Solver training step. It runs the weighted optimization cycle, logs the
    resulting loss under ``train_loss`` and returns it.

    :param list[tuple[str, dict]] batch: A batch of data. Each element is a
        tuple containing a condition name and a dictionary of points.
    :param int batch_idx: The index of the current batch.
    :param dict kwargs: Additional keyword arguments passed to
        ``_optimization_cycle``.
    :return: The loss of the training step.
    :rtype: torch.Tensor
    """
    loss = self._optimization_cycle(
        batch=batch, batch_idx=batch_idx, **kwargs
    )
    self.store_log("train_loss", loss, self.get_batch_size(batch))
    return loss
154+
@torch.set_grad_enabled(True)
def validation_step(self, batch, **kwargs):
    """
    The validation step for the PINN solver. It returns the average, over
    the conditions in the batch, of the mean non-aggregated loss of each
    condition, and logs it under ``val_loss``.

    :param list[tuple[str, dict]] batch: A batch of data. Each element is a
        tuple containing a condition name and a dictionary of points.
    :param dict kwargs: Additional keyword arguments passed to
        ``optimization_cycle``.
    :return: The loss of the validation step.
    :rtype: torch.Tensor
    """
    # NOTE(review): gradients are explicitly enabled by the decorator,
    # presumably because evaluating the residuals needs autograd even
    # outside of training -- confirm against ``optimization_cycle``.
    residuals = self.optimization_cycle(batch=batch, **kwargs)

    # Reduce each condition to a scalar without mutating the dictionary
    # returned by ``optimization_cycle``.
    cond_means = [res.mean() for res in residuals.values()]

    # Unweighted average across conditions, cast to a plain tensor.
    loss = (sum(cond_means) / len(cond_means)).as_subclass(torch.Tensor)
    self.store_log("val_loss", loss, self.get_batch_size(batch))
    return loss
177+
178+ @torch .set_grad_enabled (True )
179+ def test_step (self , batch , ** kwargs ):
145180 """
146- Computation of the scalar loss.
181+ The test step for the PINN solver. It returns the average residual
182+ computed with the ``loss`` function not aggregated.
147183
148- :param LabelTensor loss_value: the tensor of pointwise losses.
149- :raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
150- :return: The computed scalar loss.
151- :rtype: LabelTensor
184+ :param list[tuple[str, dict]] batch: A batch of data. Each element is a
185+ tuple containing a condition name and a dictionary of points.
186+ :param dict kwargs: Additional keyword arguments passed to
187+ ``optimization_cycle``.
188+ :return: The loss of the test step.
189+ :rtype: torch.Tensor
152190 """
153- if self .loss .reduction == "mean" :
154- ret = torch .mean (loss_value )
155- elif self .loss .reduction == "sum" :
156- ret = torch .sum (loss_value )
157- else :
158- raise RuntimeError (
159- f"Invalid reduction, got { self .loss .reduction } "
160- "but expected mean or sum."
191+ losses = self .optimization_cycle (batch = batch , ** kwargs )
192+
193+ # Aggregate losses for each condition
194+ for cond , loss in losses .items ():
195+ losses [cond ] = losses [cond ].mean ()
196+
197+ loss = (sum (losses .values ()) / len (losses )).as_subclass (torch .Tensor )
198+ self .store_log ("test_loss" , loss , self .get_batch_size (batch ))
199+ return loss
200+
def _optimization_cycle(self, batch, batch_idx, **kwargs):
    """
    Compute the weighted loss of each condition in the batch and aggregate
    them into a single loss.

    The non-aggregated residuals are computed first, the per-point weights
    are updated from them, and only then are the (already updated) weights
    applied to the residuals.

    :param list[tuple[str, dict]] batch: A batch of data. Each element is a
        tuple containing a condition name and a dictionary of points.
    :param int batch_idx: The index of the current batch.
    :param dict kwargs: Additional keyword arguments passed to
        ``optimization_cycle``.
    :return: The loss aggregated over all conditions by the ``weighting``
        attribute, cast to a plain :class:`torch.Tensor`.
    :rtype: torch.Tensor
    """
    # compute non-aggregated residuals
    residuals = self.optimization_cycle(batch, **kwargs)

    # update weights based on residuals
    self._update_weights(batch, batch_idx, residuals)

    # compute losses
    losses = {}
    for cond, res in residuals.items():

        # Get the correct indices for the weights. Modulus is used according
        # to the number of points in the condition, as in the PinaDataset.
        # NOTE(review): the offset is ``batch_idx * len(res)``, which assumes
        # every batch of this condition has the same length -- confirm.
        len_res = len(res)
        idx = torch.arange(
            batch_idx * len_res,
            (batch_idx + 1) * len_res,
            device=res.device,
        ) % len(self.problem.input_pts[cond])

        losses[cond] = (res * self.weights[cond][idx]).mean()

        # store log
        self.store_log(
            f"{cond}_loss", losses[cond].item(), self.get_batch_size(batch)
        )

    # clamp unknown parameters in InverseProblem (if needed)
    self._clamp_params()

    # aggregate
    return self.weighting.aggregate(losses).as_subclass(torch.Tensor)
248+
def _update_weights(self, batch, batch_idx, residuals):
    """
    Update, in place, the per-point weights of every condition in the batch
    based on the residuals.

    For the points of the current batch the update reads
    ``w <- gamma * w + eta * |r| / (max|r| + 1e-12)``, where the maximum is
    taken over the residuals of the condition in this batch.

    :param list[tuple[str, dict]] batch: A batch of data. Each element is a
        tuple containing a condition name and a dictionary of points.
    :param int batch_idx: The index of the current batch.
    :param dict residuals: A dictionary containing the residuals for each
        condition. The keys are the condition names and the values are the
        residuals as tensors.
    """
    # Iterate over each condition in the batch
    for cond, data in batch:

        # Compute normalized residuals. The residuals are detached up front:
        # the weights never carry gradients, so there is no reason to build
        # an autograd graph for the update. The small constant guards the
        # division when all residuals are zero.
        res = residuals[cond]
        res_abs = res.detach().abs()
        r_norm = (self.eta * res_abs) / (res_abs.max() + 1e-12)

        # Get the correct indices for the weights. Modulus is used according
        # to the number of points in the condition, as in the PinaDataset.
        # NOTE(review): the offset is ``batch_idx * len_pts``, which assumes
        # every batch of this condition has the same length -- confirm.
        len_pts = len(data["input"])
        idx = torch.arange(
            batch_idx * len_pts,
            (batch_idx + 1) * len_pts,
            device=res.device,
        ) % len(self.problem.input_pts[cond])

        # Update weights (written back into the stored tensor)
        weights = self.weights[cond]
        weights[idx] = self.gamma * weights[idx] + r_norm
0 commit comments