Skip to content

Commit 3690c53

Browse files
committed
loss_data in pinn interface
1 parent 2e8d833 commit 3690c53

1 file changed

Lines changed: 17 additions & 0 deletions

File tree

pina/solver/physics_informed_solver/pinn_interface.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Module for the Physics-Informed Neural Network Interface."""
22

33
from abc import ABCMeta, abstractmethod
4+
from typing import override
45
import torch
56

67
from ..supervised_solver import SupervisedSolverInterface
@@ -114,6 +115,22 @@ def test_step(self, batch):
114115
"""
115116
return super().test_step(batch, loss_residuals=self._residual_loss)
116117

118+
def loss_data(self, input, target):
119+
"""
120+
Compute the data loss for the PINN solver by evaluating the loss
121+
between the network's output and the true solution. This method should
122+
be overridden by the derived class.
123+
:param LabelTensor input: The input to the neural network.
124+
:param LabelTensor target: The target to compare with the
125+
network's output.
126+
:return: The supervised loss, averaged over the number of observations.
127+
:rtype: LabelTensor
128+
"""
129+
raise NotImplementedError(
130+
"PINN is being used in a supervised learning context, but the "
131+
"'loss_data' method has not been implemented. "
132+
)
133+
117134
@abstractmethod
118135
def loss_phys(self, samples, equation):
119136
"""

0 commit comments

Comments
 (0)