Skip to content

Commit 8292542

Browse files
GiovanniCanaliFilippoOlivo
authored andcommitted
fix helmholtz problem
1 parent 12c1ff5 commit 8292542

File tree

3 files changed

+29
-18
lines changed

3 files changed

+29
-18
lines changed

pina/_src/equation/equation_factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,15 +382,15 @@ class Helmholtz(Equation): # pylint: disable=R0903
382382
383383
\Delta u + k u - f = 0
384384
385-
Here, :math:`k` is a parameter of the equation, while :math:`f` is the
385+
Here, :math:`k` is the squared wavenumber, while :math:`f` is the
386386
forcing term.
387387
"""
388388

389389
def __init__(self, k, forcing_term):
390390
"""
391391
Initialization of the :class:`Helmholtz` class.
392392
393-
:param k: The parameter of the equation.
393+
:param k: The squared wavenumber.
394394
:type k: float | int
395395
:param Callable forcing_term: The forcing field function, taking as
396396
input the points on which evaluation is required.

pina/_src/problem/zoo/helmholtz.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,42 +37,51 @@ class HelmholtzProblem(SpatialProblem):
3737
"boundary": Condition(domain="boundary", equation=FixedValue(0.0)),
3838
}
3939

40-
def __init__(self, alpha=3.0):
40+
def __init__(self, k=1.0, alpha_x=1, alpha_y=4):
4141
"""
4242
Initialization of the :class:`HelmholtzProblem` class.
4343
44-
:param alpha: Parameter of the forcing term. Default is 3.0.
45-
:type alpha: float | int
44+
:param k: The squared wavenumber. Default is 1.0.
45+
:type k: float | int
46+
:param int alpha_x: The frequency in the x-direction. Default is 1.
47+
:param int alpha_y: The frequency in the y-direction. Default is 4.
4648
"""
4749
super().__init__()
48-
check_consistency(alpha, (int, float))
49-
self.alpha = alpha
50+
check_consistency(k, (int, float))
51+
check_consistency(alpha_x, int)
52+
check_consistency(alpha_y, int)
53+
self.k = k
54+
self.alpha_x = alpha_x
55+
self.alpha_y = alpha_y
5056

5157
def forcing_term(input_):
5258
"""
5359
Implementation of the forcing term.
5460
"""
61+
x, y, pi = input_["x"], input_["y"], torch.pi
62+
factor = (self.alpha_x**2 + self.alpha_y**2) * pi**2
5563
return (
56-
(1 - 2 * (self.alpha * torch.pi) ** 2)
57-
* torch.sin(self.alpha * torch.pi * input_.extract("x"))
58-
* torch.sin(self.alpha * torch.pi * input_.extract("y"))
64+
(self.k - factor)
65+
* torch.sin(self.alpha_x * pi * x)
66+
* torch.sin(self.alpha_y * pi * y)
5967
)
6068

6169
self.conditions["D"] = Condition(
6270
domain="D",
63-
equation=Helmholtz(self.alpha, forcing_term),
71+
equation=Helmholtz(self.k, forcing_term),
6472
)
6573

6674
def solution(self, pts):
6775
"""
6876
Implementation of the analytical solution of the Helmholtz problem.
6977
7078
:param LabelTensor pts: Points where the solution is evaluated.
71-
:return: The analytical solution of the Poisson problem.
79+
:return: The analytical solution of the Helmholtz problem.
7280
:rtype: LabelTensor
7381
"""
74-
sol = torch.sin(self.alpha * torch.pi * pts.extract("x")) * torch.sin(
75-
self.alpha * torch.pi * pts.extract("y")
82+
x, y, pi = pts["x"], pts["y"], torch.pi
83+
sol = torch.sin(self.alpha_x * pi * x) * torch.sin(
84+
self.alpha_y * pi * y
7685
)
7786
sol.labels = self.output_variables
7887
return sol

tests/test_problem_zoo/test_helmholtz.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,17 @@
33
from pina.problem import SpatialProblem
44

55

6-
@pytest.mark.parametrize("alpha", [1.5, 3])
7-
def test_constructor(alpha):
6+
@pytest.mark.parametrize("k", [1.5, 3])
7+
@pytest.mark.parametrize("alpha_x", [1, 3])
8+
@pytest.mark.parametrize("alpha_y", [1, 3])
9+
def test_constructor(k, alpha_x, alpha_y):
810

9-
problem = HelmholtzProblem(alpha=alpha)
11+
problem = HelmholtzProblem(k=k, alpha_x=alpha_x, alpha_y=alpha_y)
1012
problem.discretise_domain(n=10, mode="random", domains="all")
1113
assert problem.are_all_domains_discretised
1214
assert isinstance(problem, SpatialProblem)
1315
assert hasattr(problem, "conditions")
1416
assert isinstance(problem.conditions, dict)
1517

1618
with pytest.raises(ValueError):
17-
HelmholtzProblem(alpha="invalid")
19+
HelmholtzProblem(k=1, alpha_x=1.5, alpha_y=1)

0 commit comments

Comments
 (0)