Skip to content

Commit a140784

Browse files
add pirate network
1 parent 6d10989 commit a140784

9 files changed

Lines changed: 368 additions & 0 deletions

File tree

docs/source/_rst/_code.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ Models
104104
LowRankNeuralOperator <model/low_rank_neural_operator.rst>
105105
GraphNeuralOperator <model/graph_neural_operator.rst>
106106
GraphNeuralKernel <model/graph_neural_operator_integral_kernel.rst>
107+
PirateNet <model/pirate_network.rst>
107108

108109
Blocks
109110
-------------
@@ -121,6 +122,7 @@ Blocks
121122
Continuous Convolution Interface <model/block/convolution_interface.rst>
122123
Continuous Convolution Block <model/block/convolution.rst>
123124
Orthogonal Block <model/block/orthogonal.rst>
125+
PirateNet Block <model/block/pirate_network_block.rst>
124126

125127
Message Passing
126128
-------------------
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
PirateNet Block
2+
=======================================
3+
.. currentmodule:: pina.model.block.pirate_network_block
4+
5+
.. autoclass:: PirateNetBlock
6+
:members:
7+
:show-inheritance:
8+
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
PirateNet
2+
=======================
3+
.. currentmodule:: pina.model.pirate_network
4+
5+
.. autoclass:: PirateNet
6+
:members:
7+
:show-inheritance:

pina/model/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"LowRankNeuralOperator",
1414
"Spline",
1515
"GraphNeuralOperator",
16+
"PirateNet",
1617
]
1718

1819
from .feed_forward import FeedForward, ResidualFeedForward
@@ -24,3 +25,4 @@
2425
from .low_rank_neural_operator import LowRankNeuralOperator
2526
from .spline import Spline
2627
from .graph_neural_operator import GraphNeuralOperator
28+
from .pirate_network import PirateNet

pina/model/block/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"LowRankBlock",
1919
"RBFBlock",
2020
"GNOBlock",
21+
"PirateNetBlock",
2122
]
2223

2324
from .convolution_2d import ContinuousConvBlock
@@ -35,3 +36,4 @@
3536
from .low_rank_block import LowRankBlock
3637
from .rbf_block import RBFBlock
3738
from .gno_block import GNOBlock
39+
from .pirate_network_block import PirateNetBlock
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import torch
2+
from ...utils import check_consistency, check_positive_integer
3+
4+
5+
class PirateNetBlock(torch.nn.Module):
6+
"""
7+
The inner block of Physics-Informed residual adaptive network (PirateNet).
8+
9+
The block consists of three dense layers with dual gating operations and an
10+
adaptive residual connection. The trainable ``alpha`` parameter controls
11+
the contribution of the residual connection.
12+
13+
.. seealso::
14+
15+
**Original reference**:
16+
Wang, S., Sankaran, S., Stinis., P., Perdikaris, P. (2025).
17+
*Simulating Three-dimensional Turbulence with Physics-informed Neural
18+
Networks*.
19+
DOI: `arXiv preprint arXiv:2507.08972.
20+
<https://arxiv.org/abs/2507.08972>`_
21+
"""
22+
23+
def __init__(self, inner_size, activation):
24+
"""
25+
Initialization of the :class:`PirateNetBlock` class.
26+
27+
:param int inner_size: The number of hidden units in the dense layers.
28+
:param torch.nn.Module activation: The activation function.
29+
"""
30+
super().__init__()
31+
32+
# Check consistency
33+
check_consistency(activation, torch.nn.Module, subclass=True)
34+
check_positive_integer(inner_size, strict=True)
35+
36+
# Initialize the linear transformations of the dense layers
37+
self.linear1 = torch.nn.Linear(inner_size, inner_size)
38+
self.linear2 = torch.nn.Linear(inner_size, inner_size)
39+
self.linear3 = torch.nn.Linear(inner_size, inner_size)
40+
41+
# Initialize the scales of the dense layers
42+
self.scale1 = torch.nn.Parameter(torch.zeros(inner_size))
43+
self.scale2 = torch.nn.Parameter(torch.zeros(inner_size))
44+
self.scale3 = torch.nn.Parameter(torch.zeros(inner_size))
45+
46+
# Initialize the adaptive residual connection parameter
47+
self.alpha = torch.nn.Parameter(torch.zeros(1))
48+
49+
# Initialize the activation function
50+
self.activation = activation()
51+
52+
def forward(self, x, U, V):
53+
"""
54+
Forward pass of the PirateNet block. It computes the output of the block
55+
by applying the dense layers with scaling, and combines the results with
56+
the input using the adaptive residual connection.
57+
58+
:param x: The input tensor.
59+
:type x: torch.Tensor | LabelTensor
60+
:param torch.Tensor U: The first shared gating tensor.
61+
:param torch.Tensor V: The second shared gating tensor.
62+
:return: The output tensor of the block.
63+
:rtype: torch.Tensor | LabelTensor
64+
"""
65+
# Compute the output of the first dense layer with scaling
66+
print(f"{x.shape=}")
67+
print(f"{self.linear1(x).shape=}")
68+
print(f"{self.scale1.shape=}")
69+
print(f"{torch.exp(self.scale1).shape=}")
70+
print(f"{(self.linear1(x) * torch.exp(self.scale1)).shape=}")
71+
f = self.activation(self.linear1(x) * torch.exp(self.scale1))
72+
print(f.shape, U.shape)
73+
z1 = f * U + (1 - f) * V
74+
75+
# Compute the output of the second dense layer with scaling
76+
g = self.activation(self.linear2(z1) * torch.exp(self.scale2))
77+
z2 = g * U + (1 - g) * V
78+
79+
# Compute the output of the block
80+
h = self.activation(self.linear3(z2) * torch.exp(self.scale3))
81+
return self.alpha * h + (1 - self.alpha) * x

pina/model/pirate_network.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import torch
2+
from .block import FourierFeatureEmbedding, PirateNetBlock
3+
from ..utils import check_consistency, check_positive_integer
4+
5+
6+
class PirateNet(torch.nn.Module):
7+
"""
8+
Implementation of Physics-Informed residual adaptive network (PirateNet).
9+
10+
The model consists of a Fourier feature embedding layer, multiple PirateNet
11+
blocks, and a final output layer. Each PirateNet block consist of three
12+
dense layers with dual gating mechanism and an adaptive residual connection,
13+
whose contribution is controlled by a trainable parameter ``alpha``.
14+
15+
The PirateNet, augmented with random weight factorization, is designed to
16+
mitigate spectral bias in deep networks.
17+
18+
.. seealso::
19+
20+
**Original reference**:
21+
Wang, S., Sankaran, S., Stinis., P., Perdikaris, P. (2025).
22+
*Simulating Three-dimensional Turbulence with Physics-informed Neural
23+
Networks*.
24+
DOI: `arXiv preprint arXiv:2507.08972.
25+
<https://arxiv.org/abs/2507.08972>`_
26+
"""
27+
28+
def __init__(
29+
self,
30+
input_dimension,
31+
inner_size,
32+
output_dimension,
33+
sigma,
34+
n_layers=3,
35+
activation=torch.nn.Tanh,
36+
):
37+
"""
38+
Initialization of the :class:`PirateNet` class.
39+
40+
:param int input_dimension: The number of input features.
41+
:param int inner_size: The number of hidden units in the dense layers.
42+
:param int output_dimension: The number of output features.
43+
:param float sigma: The scaling factor for the Fourier embedding. This
44+
value must reflect the granularity of the scale in the solution of
45+
the problem to be solved.
46+
:param int n_layers: The number of PirateNet blocks in the model.
47+
Default is 3.
48+
:param torch.nn.Module activation: The activation function to be used in
49+
the blocks. Default is :class:`torch.nn.Tanh`.
50+
"""
51+
super().__init__()
52+
53+
# Check consistency
54+
check_consistency(sigma, (float, int))
55+
check_consistency(activation, torch.nn.Module, subclass=True)
56+
check_positive_integer(input_dimension, strict=True)
57+
check_positive_integer(inner_size, strict=True)
58+
check_positive_integer(output_dimension, strict=True)
59+
check_positive_integer(n_layers, strict=True)
60+
61+
# Initialize the activation function
62+
self.activation = activation()
63+
64+
# Initialize the Fourier embedding
65+
self.fourier_embedding = FourierFeatureEmbedding(
66+
input_dimension=input_dimension,
67+
output_dimension=inner_size,
68+
sigma=sigma,
69+
)
70+
71+
# Initialize the shared dense layers
72+
self.linear1 = torch.nn.Linear(inner_size, inner_size)
73+
self.linear2 = torch.nn.Linear(inner_size, inner_size)
74+
75+
# Initialize the PirateNet blocks
76+
self.blocks = torch.nn.ModuleList(
77+
[PirateNetBlock(inner_size, activation) for _ in range(n_layers)]
78+
)
79+
80+
# Initialize the output layer
81+
self.output_layer = torch.nn.Linear(inner_size, output_dimension)
82+
83+
def forward(self, input_):
84+
"""
85+
Forward pass of the PirateNet model. It applies the Fourier feature
86+
embedding, computes the shared gating tensors U and V, and passes the
87+
input through each block in the network. Finally, it applies the output
88+
layer to produce the final output.
89+
90+
:param input_: The input tensor for the model.
91+
:type input_: torch.Tensor | LabelTensor
92+
:return: The output tensor of the model.
93+
:rtype: torch.Tensor | LabelTensor
94+
"""
95+
# Apply the Fourier feature embedding
96+
x = self.fourier_embedding(input_)
97+
98+
# Compute U and V from the shared dense layers
99+
U = self.activation(self.linear1(x))
100+
V = self.activation(self.linear2(x))
101+
102+
# Pass through each block in the network
103+
for block in self.blocks:
104+
x = block(x, U, V)
105+
106+
return self.output_layer(x)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import torch
2+
import pytest
3+
from pina.model.block import PirateNetBlock
4+
5+
data = torch.rand((20, 3))
6+
7+
8+
@pytest.mark.parametrize("inner_size", [10, 20])
9+
def test_constructor(inner_size):
10+
11+
PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
12+
13+
# Should fail if inner_size is negative
14+
with pytest.raises(AssertionError):
15+
PirateNetBlock(inner_size=-1, activation=torch.nn.Tanh)
16+
17+
18+
@pytest.mark.parametrize("inner_size", [10, 20])
19+
def test_forward(inner_size):
20+
21+
model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
22+
23+
# Create dummy embedding
24+
dummy_embedding = torch.nn.Linear(data.shape[1], inner_size)
25+
x = dummy_embedding(data)
26+
27+
# Create dummy U and V tensors
28+
U = torch.rand((data.shape[0], inner_size))
29+
V = torch.rand((data.shape[0], inner_size))
30+
31+
output_ = model(x, U, V)
32+
assert output_.shape == (data.shape[0], inner_size)
33+
34+
35+
@pytest.mark.parametrize("inner_size", [10, 20])
36+
def test_backward(inner_size):
37+
38+
model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)
39+
data.requires_grad_()
40+
41+
# Create dummy embedding
42+
dummy_embedding = torch.nn.Linear(data.shape[1], inner_size)
43+
x = dummy_embedding(data)
44+
45+
# Create dummy U and V tensors
46+
U = torch.rand((data.shape[0], inner_size))
47+
V = torch.rand((data.shape[0], inner_size))
48+
49+
output_ = model(x, U, V)
50+
51+
loss = torch.mean(output_)
52+
loss.backward()
53+
assert data.grad.shape == data.shape

0 commit comments

Comments
 (0)