-
Notifications
You must be signed in to change notification settings - Fork 103
Add pirate network #604
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Add pirate network #604
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| PirateNet Block | ||
| ======================================= | ||
| .. currentmodule:: pina.model.block.pirate_network_block | ||
|
|
||
| .. autoclass:: PirateNetBlock | ||
| :members: | ||
| :show-inheritance: | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| PirateNet | ||
| ======================= | ||
| .. currentmodule:: pina.model.pirate_network | ||
|
|
||
| .. autoclass:: PirateNet | ||
| :members: | ||
| :show-inheritance: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,89 @@ | ||
| """Module for the PirateNet block class.""" | ||
|
|
||
| import torch | ||
| from ...utils import check_consistency, check_positive_integer | ||
|
|
||
|
|
||
| class PirateNetBlock(torch.nn.Module): | ||
| """ | ||
| The inner block of Physics-Informed residual adaptive network (PirateNet). | ||
|
|
||
| The block consists of three dense layers with dual gating operations and an | ||
| adaptive residual connection. The trainable ``alpha`` parameter controls | ||
| the contribution of the residual connection. | ||
|
|
||
| .. seealso:: | ||
|
|
||
| **Original reference**: | ||
| Wang, S., Sankaran, S., Stinis., P., Perdikaris, P. (2025). | ||
| *Simulating Three-dimensional Turbulence with Physics-informed Neural | ||
| Networks*. | ||
| DOI: `arXiv preprint arXiv:2507.08972. | ||
| <https://arxiv.org/abs/2507.08972>`_ | ||
| """ | ||
|
|
||
| def __init__(self, inner_size, activation): | ||
| """ | ||
| Initialization of the :class:`PirateNetBlock` class. | ||
|
|
||
| :param int inner_size: The number of hidden units in the dense layers. | ||
| :param torch.nn.Module activation: The activation function. | ||
| """ | ||
| super().__init__() | ||
|
|
||
| # Check consistency | ||
| check_consistency(activation, torch.nn.Module, subclass=True) | ||
| check_positive_integer(inner_size, strict=True) | ||
|
|
||
| # Initialize the linear transformations of the dense layers | ||
| self.linear1 = torch.nn.Linear(inner_size, inner_size) | ||
| self.linear2 = torch.nn.Linear(inner_size, inner_size) | ||
| self.linear3 = torch.nn.Linear(inner_size, inner_size) | ||
|
|
||
| # Initialize the scales of the dense layers | ||
| self.scale1 = torch.nn.Parameter(torch.zeros(inner_size)) | ||
| self.scale2 = torch.nn.Parameter(torch.zeros(inner_size)) | ||
| self.scale3 = torch.nn.Parameter(torch.zeros(inner_size)) | ||
|
|
||
| # Initialize the adaptive residual connection parameter | ||
| self._alpha = torch.nn.Parameter(torch.zeros(1)) | ||
|
|
||
| # Initialize the activation function | ||
| self.activation = activation() | ||
|
|
||
| def forward(self, x, U, V): | ||
| """ | ||
| Forward pass of the PirateNet block. It computes the output of the block | ||
| by applying the dense layers with scaling, and combines the results with | ||
| the input using the adaptive residual connection. | ||
|
|
||
| :param x: The input tensor. | ||
| :type x: torch.Tensor | LabelTensor | ||
| :param torch.Tensor U: The first shared gating tensor. It must have the | ||
| same shape as ``x``. | ||
| :param torch.Tensor V: The second shared gating tensor. It must have the | ||
| same shape as ``x``. | ||
| :return: The output tensor of the block. | ||
| :rtype: torch.Tensor | LabelTensor | ||
| """ | ||
| # Compute the output of the first dense layer with scaling | ||
| f = self.activation(self.linear1(x) * torch.exp(self.scale1)) | ||
| z1 = f * U + (1 - f) * V | ||
|
|
||
| # Compute the output of the second dense layer with scaling | ||
| g = self.activation(self.linear2(z1) * torch.exp(self.scale2)) | ||
| z2 = g * U + (1 - g) * V | ||
|
|
||
| # Compute the output of the block | ||
| h = self.activation(self.linear3(z2) * torch.exp(self.scale3)) | ||
| return self._alpha * h + (1 - self._alpha) * x | ||
|
|
||
| @property | ||
| def alpha(self): | ||
| """ | ||
| Return the alpha parameter. | ||
|
|
||
| :return: The alpha parameter controlling the residual connection. | ||
| :rtype: torch.nn.Parameter | ||
| """ | ||
| return self._alpha | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| """Module for the PirateNet model class.""" | ||
|
|
||
| import torch | ||
| from .block import FourierFeatureEmbedding, PirateNetBlock | ||
| from ..utils import check_consistency, check_positive_integer | ||
|
|
||
|
|
||
| class PirateNet(torch.nn.Module): | ||
| """ | ||
| Implementation of Physics-Informed residual adaptive network (PirateNet). | ||
|
|
||
| The model consists of a Fourier feature embedding layer, multiple PirateNet | ||
| blocks, and a final output layer. Each PirateNet block consist of three | ||
| dense layers with dual gating mechanism and an adaptive residual connection, | ||
| whose contribution is controlled by a trainable parameter ``alpha``. | ||
|
|
||
| The PirateNet, augmented with random weight factorization, is designed to | ||
| mitigate spectral bias in deep networks. | ||
|
|
||
| .. seealso:: | ||
|
|
||
| **Original reference**: | ||
| Wang, S., Sankaran, S., Stinis., P., Perdikaris, P. (2025). | ||
| *Simulating Three-dimensional Turbulence with Physics-informed Neural | ||
| Networks*. | ||
| DOI: `arXiv preprint arXiv:2507.08972. | ||
| <https://arxiv.org/abs/2507.08972>`_ | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| input_dimension, | ||
| inner_size, | ||
| output_dimension, | ||
| embedding=None, | ||
| n_layers=3, | ||
| activation=torch.nn.Tanh, | ||
| ): | ||
| """ | ||
| Initialization of the :class:`PirateNet` class. | ||
|
|
||
| :param int input_dimension: The number of input features. | ||
| :param int inner_size: The number of hidden units in the dense layers. | ||
| :param int output_dimension: The number of output features. | ||
| :param torch.nn.Module embedding: The embedding module used to transform | ||
| the input into a higher-dimensional feature space. If ``None``, a | ||
| default :class:`~pina.model.block.FourierFeatureEmbedding` with | ||
| scaling factor of 2 is used. Default is ``None``. | ||
| :param int n_layers: The number of PirateNet blocks in the model. | ||
| Default is 3. | ||
| :param torch.nn.Module activation: The activation function to be used in | ||
| the blocks. Default is :class:`torch.nn.Tanh`. | ||
| """ | ||
| super().__init__() | ||
|
|
||
| # Check consistency | ||
| check_consistency(activation, torch.nn.Module, subclass=True) | ||
| check_positive_integer(input_dimension, strict=True) | ||
| check_positive_integer(inner_size, strict=True) | ||
| check_positive_integer(output_dimension, strict=True) | ||
| check_positive_integer(n_layers, strict=True) | ||
|
|
||
| # Initialize the activation function | ||
| self.activation = activation() | ||
|
|
||
| # Initialize the Fourier embedding | ||
| self.embedding = embedding or FourierFeatureEmbedding( | ||
| input_dimension=input_dimension, | ||
| output_dimension=inner_size, | ||
| sigma=2.0, | ||
| ) | ||
|
|
||
| # Initialize the shared dense layers | ||
| self.linear1 = torch.nn.Linear(inner_size, inner_size) | ||
| self.linear2 = torch.nn.Linear(inner_size, inner_size) | ||
|
|
||
| # Initialize the PirateNet blocks | ||
| self.blocks = torch.nn.ModuleList( | ||
| [PirateNetBlock(inner_size, activation) for _ in range(n_layers)] | ||
| ) | ||
|
|
||
| # Initialize the output layer | ||
| self.output_layer = torch.nn.Linear(inner_size, output_dimension) | ||
|
|
||
| def forward(self, input_): | ||
| """ | ||
| Forward pass of the PirateNet model. It applies the Fourier feature | ||
| embedding, computes the shared gating tensors U and V, and passes the | ||
| input through each block in the network. Finally, it applies the output | ||
| layer to produce the final output. | ||
|
|
||
| :param input_: The input tensor for the model. | ||
| :type input_: torch.Tensor | LabelTensor | ||
| :return: The output tensor of the model. | ||
| :rtype: torch.Tensor | LabelTensor | ||
| """ | ||
| # Apply the Fourier feature embedding | ||
| x = self.embedding(input_) | ||
|
|
||
| # Compute U and V from the shared dense layers | ||
| U = self.activation(self.linear1(x)) | ||
| V = self.activation(self.linear2(x)) | ||
|
|
||
| # Pass through each block in the network | ||
| for block in self.blocks: | ||
| x = block(x, U, V) | ||
|
|
||
| return self.output_layer(x) | ||
|
|
||
| @property | ||
| def alpha(self): | ||
| """ | ||
| Return the alpha values of all PirateNetBlock layers. | ||
|
|
||
| :return: A list of alpha values from each block. | ||
| :rtype: list | ||
| """ | ||
| return [block.alpha.item() for block in self.blocks] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| import torch | ||
| import pytest | ||
| from pina.model.block import PirateNetBlock | ||
|
|
||
| data = torch.rand((20, 3)) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("inner_size", [10, 20]) | ||
| def test_constructor(inner_size): | ||
|
|
||
| PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh) | ||
|
|
||
| # Should fail if inner_size is negative | ||
| with pytest.raises(AssertionError): | ||
| PirateNetBlock(inner_size=-1, activation=torch.nn.Tanh) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("inner_size", [10, 20]) | ||
| def test_forward(inner_size): | ||
|
|
||
| model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh) | ||
|
|
||
| # Create dummy embedding | ||
| dummy_embedding = torch.nn.Linear(data.shape[1], inner_size) | ||
| x = dummy_embedding(data) | ||
|
|
||
| # Create dummy U and V tensors | ||
| U = torch.rand((data.shape[0], inner_size)) | ||
| V = torch.rand((data.shape[0], inner_size)) | ||
|
|
||
| output_ = model(x, U, V) | ||
| assert output_.shape == (data.shape[0], inner_size) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("inner_size", [10, 20]) | ||
| def test_backward(inner_size): | ||
|
|
||
| model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh) | ||
| data.requires_grad_() | ||
|
|
||
| # Create dummy embedding | ||
| dummy_embedding = torch.nn.Linear(data.shape[1], inner_size) | ||
| x = dummy_embedding(data) | ||
|
|
||
| # Create dummy U and V tensors | ||
| U = torch.rand((data.shape[0], inner_size)) | ||
| V = torch.rand((data.shape[0], inner_size)) | ||
|
|
||
| output_ = model(x, U, V) | ||
|
|
||
| loss = torch.mean(output_) | ||
| loss.backward() | ||
| assert data.grad.shape == data.shape |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.