|
| 1 | +"""Module for the DeepEnsemble simple solver.""" |
| 2 | + |
| 3 | +from pina._src.solver.multi_model_simple_solver import MultiModelSimpleSolver |
| 4 | + |
| 5 | + |
| 6 | +class DeepEnsembleSimpleSolver(MultiModelSimpleSolver): |
| 7 | + r""" |
| 8 | + Deep Ensemble Simple Solver class. This class implements a Deep Ensemble |
| 9 | + solver for generic conditions (data, equations, or domain residuals) using |
| 10 | + user-specified ``models`` to solve a specific ``problem``. |
| 11 | +
|
| 12 | + It is the ensemble counterpart of |
| 13 | + :class:`~pina.solver.SingleModelSimpleSolver`: each model in the ensemble |
| 14 | + evaluates every condition independently, and the per-model scalar losses |
| 15 | + are averaged to produce the final condition loss. |
| 16 | +
|
| 17 | + An ensemble model is constructed by combining multiple models that solve |
| 18 | + the same type of problem. Mathematically, this creates an implicit |
| 19 | + distribution :math:`p(\mathbf{u} \mid \mathbf{s})` over the possible |
| 20 | + outputs :math:`\mathbf{u}`, given the original input :math:`\mathbf{s}`. |
| 21 | + The models :math:`\mathcal{M}_{i\in (1,\dots,r)}` in |
| 22 | + the ensemble work collaboratively to capture different |
| 23 | + aspects of the data or task, with each model contributing a distinct |
| 24 | + prediction |
| 25 | + :math:`\mathbf{y}_{i}=\mathcal{M}_i(\mathbf{u} \mid \mathbf{s})`. |
| 26 | + By aggregating these predictions, the ensemble |
| 27 | + model can achieve greater robustness and accuracy compared to individual |
| 28 | + models, leveraging the diversity of the models to reduce overfitting and |
| 29 | + improve generalization. Furthemore, statistical metrics can |
| 30 | + be computed, e.g. the ensemble mean and variance: |
| 31 | +
|
| 32 | + .. math:: |
| 33 | + \mathbf{\mu} = \frac{1}{N}\sum_{i=1}^r \mathbf{y}_{i} |
| 34 | +
|
| 35 | + .. math:: |
| 36 | + \mathbf{\sigma^2} = \frac{1}{N}\sum_{i=1}^r |
| 37 | + (\mathbf{y}_{i} - \mathbf{\mu})^2 |
| 38 | +
|
| 39 | + During training the condition loss is minimised by each ensemble model |
| 40 | + independently and then averaged: |
| 41 | +
|
| 42 | + .. math:: |
| 43 | + \mathcal{L}_{\rm{condition}} = \frac{1}{N_{\rm{ensemble}}} |
| 44 | + \sum_{i=1}^{N_{\rm{ensemble}}} |
| 45 | + \mathcal{L}_i(\mathcal{M}_i, \mathbf{s}) |
| 46 | +
|
| 47 | + where :math:`\mathcal{L}` is a specific loss function, typically the MSE: |
| 48 | +
|
| 49 | + .. math:: |
| 50 | + \mathcal{L}(v) = \| v \|^2_2. |
| 51 | +
|
| 52 | + .. seealso:: |
| 53 | +
|
| 54 | + **Original reference**: Lakshminarayanan, B., Pritzel, A., & Blundell, |
| 55 | + C. (2017). *Simple and scalable predictive uncertainty estimation |
| 56 | + using deep ensembles*. Advances in neural information |
| 57 | + processing systems, 30. |
| 58 | + DOI: `arXiv:1612.01474 <https://arxiv.org/abs/1612.01474>`_. |
| 59 | + """ |
| 60 | + |
| 61 | + def __init__( |
| 62 | + self, |
| 63 | + problem, |
| 64 | + models, |
| 65 | + optimizers=None, |
| 66 | + schedulers=None, |
| 67 | + weighting=None, |
| 68 | + loss=None, |
| 69 | + use_lt=True, |
| 70 | + ensemble_dim=0, |
| 71 | + ): |
| 72 | + """ |
| 73 | + Initialization of the :class:`DeepEnsembleSimpleSolver` class. |
| 74 | +
|
| 75 | + :param AbstractProblem problem: The problem to be solved. |
| 76 | + :param list[torch.nn.Module] models: The neural network models to be |
| 77 | + used. Must be a list or tuple with at least two models. |
| 78 | + :param list[Optimizer] optimizers: The optimizers to be used. |
| 79 | + If ``None``, the :class:`torch.optim.Adam` optimizer is used for |
| 80 | + each model. Default is ``None``. |
| 81 | + :param list[Scheduler] schedulers: The learning rate schedulers. |
| 82 | + If ``None``, :class:`torch.optim.lr_scheduler.ConstantLR` is used |
| 83 | + for each model. Default is ``None``. |
| 84 | + :param WeightingInterface weighting: The weighting schema to be used. |
| 85 | + If ``None``, no weighting schema is used. Default is ``None``. |
| 86 | + :param torch.nn.Module loss: The element-wise loss module. |
| 87 | + If ``None``, :class:`torch.nn.MSELoss` is used. Default is |
| 88 | + ``None``. |
| 89 | + :param bool use_lt: If ``True``, the solver uses LabelTensors as |
| 90 | + input. Default is ``True``. |
| 91 | + :param int ensemble_dim: The dimension along which the per-model |
| 92 | + outputs are stacked in :meth:`forward`. Default is ``0``. |
| 93 | + """ |
| 94 | + super().__init__( |
| 95 | + problem=problem, |
| 96 | + models=models, |
| 97 | + optimizers=optimizers, |
| 98 | + schedulers=schedulers, |
| 99 | + weighting=weighting, |
| 100 | + loss=loss, |
| 101 | + use_lt=use_lt, |
| 102 | + ensemble_dim=ensemble_dim, |
| 103 | + ) |
0 commit comments