9 changes: 9 additions & 0 deletions docs/source/_rst/loss/ntk_weighting.rst
@@ -0,0 +1,9 @@
NeuralTangentKernelWeighting
=============================
.. currentmodule:: pina.loss.ntk_weighting

.. automodule:: pina.loss.ntk_weighting

.. autoclass:: NeuralTangentKernelWeighting
   :members:
   :show-inheritance:
2 changes: 2 additions & 0 deletions pina/loss/__init__.py
@@ -8,10 +8,12 @@
"PowerLoss",
"WeightingInterface",
"ScalarWeighting",
"NeuralTangentKernelWeighting",
]

from .loss_interface import LossInterface
from .power_loss import PowerLoss
from .lp_loss import LpLoss
from .weighting_interface import WeightingInterface
from .scalar_weighting import ScalarWeighting
from .ntk_weighting import NeuralTangentKernelWeighting
71 changes: 71 additions & 0 deletions pina/loss/ntk_weighting.py
@@ -0,0 +1,71 @@
"""Module for Neural Tangent Kernel Class"""

import torch
from torch.nn import Module
from .weighting_interface import WeightingInterface
from ..utils import check_consistency


class NeuralTangentKernelWeighting(WeightingInterface):
"""
A neural tangent kernel scheme for weighting different losses to
boost the convergence.

.. seealso::

**Original reference**: Wang, Sifan, Xinling Yu, and
Paris Perdikaris. *When and why PINNs fail to train:
A neural tangent kernel perspective*. Journal of
Computational Physics 449 (2022): 110768.
DOI: `10.1016/j.jcp.2021.110768 <https://doi.org/10.1016/j.jcp.2021.110768>`_.



"""

    def __init__(self, model, alpha=0.5):
        """
        Initialization of the :class:`NeuralTangentKernelWeighting` class.

        :param torch.nn.Module model: The neural network model.
        :param float alpha: The moving-average parameter, in the range
            [0, 1]. It controls how much of the previous weights is
            retained at each update. Default is 0.5.
        :raises ValueError: If ``alpha`` is not between 0 and 1.
        """
        super().__init__()
        check_consistency(alpha, float)
        check_consistency(model, Module)
        if alpha < 0 or alpha > 1:
            raise ValueError("alpha must be a value between 0 and 1.")
        self.alpha = alpha
        self.model = model
        # Per-condition weights, populated lazily at the first update.
        self.weights = {}
        self.default_value_weights = 1

    def aggregate(self, losses):
        """
        Weight the losses according to the Neural Tangent Kernel
        algorithm. Each weight is updated as a convex combination,
        with coefficient ``alpha``, of its previous value and the
        normalized gradient norm of the corresponding loss.

        :param dict losses: The dictionary of losses, indexed by
            condition name.
        :return: The aggregated loss. It is a scalar Tensor.
        :rtype: torch.Tensor
        """
        losses_norm = {}
        for condition in losses:
            # Zero the gradients so that each norm reflects only the
            # gradients of the current condition's loss, not gradients
            # accumulated by the previous backward calls.
            self.model.zero_grad()
            losses[condition].backward(retain_graph=True)
            grads = torch.cat(
                [
                    param.grad.flatten()
                    for param in self.model.parameters()
                    if param.grad is not None
                ]
            )
            losses_norm[condition] = torch.norm(grads)
        # Update the weights with a moving average of the normalized
        # gradient norms.
        self.weights = {
            condition: self.alpha
            * self.weights.get(condition, self.default_value_weights)
            + (1 - self.alpha)
            * losses_norm[condition]
            / sum(losses_norm.values())
            for condition in losses
        }
        # Aggregate the losses with the updated weights.
        return sum(
            self.weights[condition] * loss for condition, loss in losses.items()
        )
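
For reviewers, a minimal usage sketch (not part of the diff) showing the update rule in isolation. It assumes this branch of pina is installed; the toy network, the input tensor, and the condition names "physics" and "boundary" are illustrative placeholders, not names used by the library.

import torch
from pina.loss import NeuralTangentKernelWeighting

torch.manual_seed(0)
model = torch.nn.Linear(2, 1)  # toy network standing in for a PINN model
x = torch.rand(10, 2)

weighting = NeuralTangentKernelWeighting(model=model, alpha=0.5)

# Two toy conditions: each loss must depend on the model parameters,
# so that its gradient norm is well defined.
losses = {
    "physics": model(x).pow(2).mean(),
    "boundary": (model(x) - 1).pow(2).mean(),
}

total = weighting.aggregate(losses)
print(total)              # scalar tensor combining both losses
print(weighting.weights)  # adaptive per-condition weights

Repeated calls to aggregate move each weight, at rate 1 - alpha, toward the share of the total gradient norm contributed by its condition.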
65 changes: 65 additions & 0 deletions tests/test_weighting/test_ntk_weighting.py
@@ -0,0 +1,65 @@
import pytest
from pina import Trainer
from pina.solver import PINN
from pina.model import FeedForward
from pina.problem.zoo import Poisson2DSquareProblem
from pina.loss import NeuralTangentKernelWeighting

problem = Poisson2DSquareProblem()
condition_names = problem.conditions.keys()


@pytest.mark.parametrize(
    "model,alpha",
    [
        (
            FeedForward(
                len(problem.input_variables), len(problem.output_variables)
            ),
            0.5,
        )
    ],
)
def test_constructor(model, alpha):
    NeuralTangentKernelWeighting(model=model, alpha=alpha)


@pytest.mark.parametrize("model", [0.5])
def test_wrong_constructor1(model):
with pytest.raises(ValueError):
NeuralTangentKernelWeighting(model)


@pytest.mark.parametrize(
    "model,alpha",
    [
        (
            FeedForward(
                len(problem.input_variables), len(problem.output_variables)
            ),
            1.2,
        )
    ],
)
def test_wrong_constructor2(model, alpha):
    with pytest.raises(ValueError):
        NeuralTangentKernelWeighting(model, alpha)


@pytest.mark.parametrize(
    "model,alpha",
    [
        (
            FeedForward(
                len(problem.input_variables), len(problem.output_variables)
            ),
            0.5,
        )
    ],
)
def test_train_aggregation(model, alpha):
    weighting = NeuralTangentKernelWeighting(model=model, alpha=alpha)
    problem.discretise_domain(50)
    solver = PINN(problem=problem, model=model, weighting=weighting)
    trainer = Trainer(solver=solver, max_epochs=5, accelerator="cpu")
    trainer.train()
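
A reviewer-side sketch of an extra check one might append inside test_train_aggregation, after trainer.train(): the weighting object retains the last per-condition weights computed by aggregate. The nonnegativity bound follows from the update rule; the membership check assumes the loss dictionary is keyed by the problem's condition names, which is why the module defines condition_names above.

    # Hypothetical post-training checks: ``weights`` is the dict
    # populated during training by ``aggregate``.
    for condition, weight in weighting.weights.items():
        assert condition in condition_names
        assert weight >= 0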