Skip to content

Commit 39da68e

Browse files
add pirate network
1 parent 6d10989 commit 39da68e

6 files changed

Lines changed: 354 additions & 0 deletions

File tree

pina/model/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"LowRankNeuralOperator",
1414
"Spline",
1515
"GraphNeuralOperator",
16+
"PirateNet",
1617
]
1718

1819
from .feed_forward import FeedForward, ResidualFeedForward
@@ -24,3 +25,4 @@
2425
from .low_rank_neural_operator import LowRankNeuralOperator
2526
from .spline import Spline
2627
from .graph_neural_operator import GraphNeuralOperator
28+
from .pirate_network import PirateNet

pina/model/block/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"LowRankBlock",
1919
"RBFBlock",
2020
"GNOBlock",
21+
"PirateNetBlock",
2122
]
2223

2324
from .convolution_2d import ContinuousConvBlock
@@ -35,3 +36,4 @@
3536
from .low_rank_block import LowRankBlock
3637
from .rbf_block import RBFBlock
3738
from .gno_block import GNOBlock
39+
from .pirate_network_block import PirateNetBlock
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import torch
2+
from ...utils import check_consistency, check_positive_integer
3+
4+
5+
class PirateNetBlock(torch.nn.Module):
    """
    The inner block of Physics-Informed residual adaptive network (PirateNet).

    The block consists of three dense layers with dual gating operations and an
    adaptive residual connection. The trainable ``alpha`` parameter controls
    the contribution of the residual connection.

    .. seealso::

        **Original reference**:
        Wang, S., Sankaran, S., Stinis, P., Perdikaris, P. (2025).
        *Simulating Three-dimensional Turbulence with Physics-informed Neural
        Networks*.
        DOI: `arXiv preprint arXiv:2507.08972.
        <https://arxiv.org/abs/2507.08972>`_
    """

    def __init__(self, inner_size, activation):
        """
        Initialization of the :class:`PirateNetBlock` class.

        Argument validation is delegated to ``check_consistency`` and
        ``check_positive_integer``.

        :param int inner_size: The number of hidden units in the dense layers.
        :param torch.nn.Module activation: The activation function. It must be
            passed as a class, since it is instantiated here.
        """
        super().__init__()

        # Check consistency
        check_consistency(activation, torch.nn.Module, subclass=True)
        check_positive_integer(inner_size, strict=True)

        # Initialize the linear transformations of the dense layers
        self.linear1 = torch.nn.Linear(inner_size, inner_size)
        self.linear2 = torch.nn.Linear(inner_size, inner_size)
        self.linear3 = torch.nn.Linear(inner_size, inner_size)

        # Initialize the scales of the dense layers. They are applied through
        # an exponential in the forward pass, so a zero value corresponds to a
        # unit multiplicative factor.
        self.scale1 = torch.nn.Parameter(torch.zeros(inner_size))
        self.scale2 = torch.nn.Parameter(torch.zeros(inner_size))
        self.scale3 = torch.nn.Parameter(torch.zeros(inner_size))

        # Initialize the adaptive residual connection parameter. Starting from
        # zero, the block initially returns its input unchanged.
        self.alpha = torch.nn.Parameter(torch.zeros(1))

        # Initialize the activation function
        self.activation = activation()

    def forward(self, x, U, V):
        """
        Forward pass of the PirateNet block. It computes the output of the block
        by applying the dense layers with scaling, and combines the results with
        the input using the adaptive residual connection.

        :param x: The input tensor.
        :type x: torch.Tensor | LabelTensor
        :param torch.Tensor U: The first shared gating tensor.
        :param torch.Tensor V: The second shared gating tensor.
        :return: The output tensor of the block.
        :rtype: torch.Tensor | LabelTensor
        """
        # Compute the output of the first dense layer with scaling, then blend
        # the two shared gating tensors using it as the mixing weight
        f = self.activation(self.linear1(x) * torch.exp(self.scale1))
        z1 = f * U + (1 - f) * V

        # Compute the output of the second dense layer with scaling
        g = self.activation(self.linear2(z1) * torch.exp(self.scale2))
        z2 = g * U + (1 - g) * V

        # Compute the output of the block: alpha weighs the transformed signal
        # against the untouched input
        h = self.activation(self.linear3(z2) * torch.exp(self.scale3))
        return self.alpha * h + (1 - self.alpha) * x

pina/model/pirate_network.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import torch
2+
from .block import FourierFeatureEmbedding, PirateNetBlock
3+
from ..utils import check_consistency, check_positive_integer
4+
5+
6+
class PirateNet(torch.nn.Module):
    """
    Physics-Informed residual adaptive network (PirateNet).

    The architecture chains a Fourier feature embedding, a stack of PirateNet
    blocks, and a linear output layer. Every block is made of three dense
    layers with a dual gating mechanism and an adaptive residual connection,
    weighted by a trainable parameter ``alpha``.

    The PirateNet, augmented with random weight factorization, is designed to
    mitigate spectral bias in deep networks.

    .. seealso::

        **Original reference**:
        Wang, S., Sankaran, S., Stinis, P., Perdikaris, P. (2025).
        *Simulating Three-dimensional Turbulence with Physics-informed Neural
        Networks*.
        DOI: `arXiv preprint arXiv:2507.08972.
        <https://arxiv.org/abs/2507.08972>`_
    """

    def __init__(
        self,
        input_dimension,
        inner_size,
        output_dimension,
        sigma,
        n_layers=3,
        activation=torch.nn.Tanh,
    ):
        """
        Initialization of the :class:`PirateNet` class.

        :param int input_dimension: The number of input features.
        :param int inner_size: The number of hidden units in the dense layers.
        :param int output_dimension: The number of output features.
        :param float sigma: The scaling factor for the Fourier embedding. This
            value must reflect the granularity of the scale in the solution of
            the problem to be solved.
        :param int n_layers: The number of PirateNet blocks in the model.
            Default is 3.
        :param torch.nn.Module activation: The activation function to be used in
            the blocks. Default is :class:`torch.nn.Tanh`.
        """
        super().__init__()

        # Validate the arguments: types first, then the integer sizes
        check_consistency(sigma, (float, int))
        check_consistency(activation, torch.nn.Module, subclass=True)
        for size in (input_dimension, inner_size, output_dimension, n_layers):
            check_positive_integer(size, strict=True)

        # Activation used to build the shared gating tensors
        self.activation = activation()

        # Fourier embedding lifting the input to the hidden width
        self.fourier_embedding = FourierFeatureEmbedding(
            input_dimension=input_dimension,
            output_dimension=inner_size,
            sigma=sigma,
        )

        # Shared dense layers producing the gating tensors U and V
        self.linear1 = torch.nn.Linear(inner_size, inner_size)
        self.linear2 = torch.nn.Linear(inner_size, inner_size)

        # Stack of PirateNet blocks
        self.blocks = torch.nn.ModuleList()
        for _ in range(n_layers):
            self.blocks.append(PirateNetBlock(inner_size, activation))

        # Linear map from the hidden width to the output features
        self.output_layer = torch.nn.Linear(inner_size, output_dimension)

    def forward(self, input_):
        """
        Forward pass of the PirateNet model. The input is embedded with the
        Fourier features, the shared gating tensors U and V are computed from
        the embedding, and the embedding is then propagated through every
        block. The output layer is applied to the result of the last block.

        :param input_: The input tensor for the model.
        :type input_: torch.Tensor | LabelTensor
        :return: The output tensor of the model.
        :rtype: torch.Tensor | LabelTensor
        """
        embedded = self.fourier_embedding(input_)

        # The gating tensors are computed once and reused by all the blocks
        gate_u = self.activation(self.linear1(embedded))
        gate_v = self.activation(self.linear2(embedded))

        hidden = embedded
        for block in self.blocks:
            hidden = block(hidden, gate_u, gate_v)

        return self.output_layer(hidden)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import torch
2+
import pytest
3+
from pina.model.block import PirateNetBlock
4+
5+
# Random input shared by the tests in this module: 20 rows with 3 features
data = torch.rand((20, 3))
6+
7+
8+
@pytest.mark.parametrize("inner_size", [10, 20])
def test_constructor(inner_size):
    """Build the block with a valid size, then check a negative one fails."""
    tanh = torch.nn.Tanh

    # A valid hidden size must be accepted
    PirateNetBlock(activation=tanh, inner_size=inner_size)

    # A negative hidden size must be rejected
    with pytest.raises(AssertionError):
        PirateNetBlock(activation=tanh, inner_size=-1)
16+
17+
18+
@pytest.mark.parametrize("inner_size", [10, 20])
def test_forward(inner_size):
    """Check that the block output keeps the (rows, inner_size) shape."""
    n_rows, n_features = data.shape

    model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)

    # Lift the raw data to the hidden width with a throwaway linear layer
    x = torch.nn.Linear(n_features, inner_size)(data)

    # Random stand-ins for the shared gating tensors
    expected_shape = (n_rows, inner_size)
    U = torch.rand(expected_shape)
    V = torch.rand(expected_shape)

    output_ = model(x, U, V)
    assert output_.shape == expected_shape
33+
34+
35+
@pytest.mark.parametrize("inner_size", [10, 20])
def test_backward(inner_size):
    """
    Check that gradients flow from the block output back to the raw input.
    """
    model = PirateNetBlock(inner_size=inner_size, activation=torch.nn.Tanh)

    # Use a private leaf copy of the shared data. Calling ``requires_grad_``
    # on the module-level tensor would change it for every other test and
    # let gradients accumulate across parametrized runs.
    inputs = data.detach().clone().requires_grad_()

    # Create dummy embedding
    dummy_embedding = torch.nn.Linear(inputs.shape[1], inner_size)
    x = dummy_embedding(inputs)

    # Create dummy U and V tensors
    U = torch.rand((inputs.shape[0], inner_size))
    V = torch.rand((inputs.shape[0], inner_size))

    output_ = model(x, U, V)

    loss = torch.mean(output_)
    loss.backward()
    assert inputs.grad is not None
    assert inputs.grad.shape == inputs.shape
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import torch
2+
import pytest
3+
from pina.model import PirateNet
4+
5+
# Random input shared by the tests in this module: 20 rows with 3 features
data = torch.rand((20, 3))
6+
7+
8+
@pytest.mark.parametrize("inner_size", [10, 20])
@pytest.mark.parametrize("n_layers", [1, 3])
@pytest.mark.parametrize("sigma", [0.1, 1, 10])
@pytest.mark.parametrize("output_dimension", [2, 4])
def test_constructor(inner_size, n_layers, sigma, output_dimension):
    """
    Build the model with valid arguments, then check that a negative value
    for each integer argument is rejected.
    """
    valid_kwargs = {
        "input_dimension": data.shape[1],
        "inner_size": inner_size,
        "output_dimension": output_dimension,
        "sigma": sigma,
        "n_layers": n_layers,
        "activation": torch.nn.Tanh,
    }

    # The fully valid configuration must be accepted
    PirateNet(**valid_kwargs)

    # Should fail if any of the integer arguments is negative
    integer_arguments = (
        "input_dimension",
        "inner_size",
        "output_dimension",
        "n_layers",
    )
    for argument in integer_arguments:
        invalid_kwargs = {**valid_kwargs, argument: -1}
        with pytest.raises(AssertionError):
            PirateNet(**invalid_kwargs)
66+
67+
68+
@pytest.mark.parametrize("inner_size", [10, 20])
@pytest.mark.parametrize("n_layers", [1, 3])
@pytest.mark.parametrize("sigma", [0.1, 1, 10])
@pytest.mark.parametrize("output_dimension", [2, 4])
def test_forward(inner_size, n_layers, sigma, output_dimension):
    """Check that the model output has shape (rows, output_dimension)."""
    n_rows, n_features = data.shape

    model = PirateNet(
        input_dimension=n_features,
        inner_size=inner_size,
        output_dimension=output_dimension,
        sigma=sigma,
        n_layers=n_layers,
        activation=torch.nn.Tanh,
    )

    prediction = model(data)
    assert prediction.shape == (n_rows, output_dimension)
85+
86+
87+
@pytest.mark.parametrize("inner_size", [10, 20])
@pytest.mark.parametrize("n_layers", [1, 3])
@pytest.mark.parametrize("sigma", [0.1, 1, 10])
@pytest.mark.parametrize("output_dimension", [2, 4])
def test_backward(inner_size, n_layers, sigma, output_dimension):
    """
    Check that gradients flow from the model output back to the input.
    """
    model = PirateNet(
        input_dimension=data.shape[1],
        inner_size=inner_size,
        output_dimension=output_dimension,
        sigma=sigma,
        n_layers=n_layers,
        activation=torch.nn.Tanh,
    )

    # Use a private leaf copy of the shared data. Calling ``requires_grad_``
    # on the module-level tensor would change it for every other test and
    # let gradients accumulate across parametrized runs.
    inputs = data.detach().clone().requires_grad_()
    output_ = model(inputs)

    loss = torch.mean(output_)
    loss.backward()
    assert inputs.grad is not None
    assert inputs.grad.shape == inputs.shape

0 commit comments

Comments
 (0)