Skip to content

Commit dd3eeaa

Browse files
ShawnL00inducer
authored andcommitted
test for m2m
1 parent 118ae83 commit dd3eeaa

1 file changed

Lines changed: 276 additions & 0 deletions

File tree

sumpy/test/test_m2m.py

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,276 @@
1+
from __future__ import annotations
2+
3+
4+
__copyright__ = """
5+
Copyright (C) 2026 Shawn/Chaoqi Lin
6+
"""
7+
8+
__license__ = """
9+
Permission is hereby granted, free of charge, to any person obtaining a copy
10+
of this software and associated documentation files (the "Software"), to deal
11+
in the Software without restriction, including without limitation the rights
12+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13+
copies of the Software, and to permit persons to whom the Software is
14+
furnished to do so, subject to the following conditions:
15+
16+
The above copyright notice and this permission notice shall be included in
17+
all copies or substantial portions of the Software.
18+
19+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
25+
THE SOFTWARE.
26+
"""
27+
28+
29+
import math
30+
import sys
31+
from typing import TYPE_CHECKING
32+
33+
import numpy as np
34+
import pytest
35+
import scipy.special as spsp
36+
import sympy as sp
37+
38+
from arraycontext import (
39+
pytest_generate_tests_for_array_contexts,
40+
)
41+
42+
import sumpy.toys as t
43+
from sumpy.array_context import PytestPyOpenCLArrayContextFactory, _acf # noqa: F401
44+
from sumpy.expansion.local import (
45+
LinearPDEConformingVolumeTaylorLocalExpansion,
46+
)
47+
from sumpy.expansion.multipole import (
48+
LinearPDEConformingVolumeTaylorMultipoleExpansion,
49+
VolumeTaylorMultipoleExpansion,
50+
)
51+
from sumpy.kernel import (
52+
BiharmonicKernel,
53+
HelmholtzKernel,
54+
Kernel,
55+
LaplaceKernel,
56+
YukawaKernel,
57+
)
58+
from sumpy.tools import build_matrix
59+
60+
61+
if TYPE_CHECKING:
62+
from collections.abc import Mapping
63+
64+
from arraycontext import ArrayContextFactory
65+
66+
67+
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
68+
PytestPyOpenCLArrayContextFactory,
69+
])
70+
71+
72+
def to_scalar(val):
73+
"""Convert symbolic or array value to scalar."""
74+
if hasattr(val, "evalf"):
75+
val = val.evalf()
76+
if hasattr(val, "item"):
77+
val = val.item()
78+
return complex(val)
79+
80+
81+
class NumericMatVecOperator:
82+
"""Wrapper for symbolic matrix-vector operator with numeric
83+
substitution."""
84+
85+
def __init__(self, symbolic_op, repl_dict):
86+
self.symbolic_op = symbolic_op
87+
self.repl_dict = repl_dict
88+
self.shape = symbolic_op.shape
89+
90+
def matvec(self, vec):
91+
result = self.symbolic_op.matvec(vec)
92+
out = []
93+
for expr in result:
94+
if hasattr(expr, "xreplace"):
95+
out.append(complex(expr.xreplace(self.repl_dict).evalf()))
96+
else:
97+
out.append(complex(expr))
98+
return np.array(out)
99+
100+
101+
def get_repl_dict(kernel, extra_kwargs):
102+
"""Numeric substitution for symbolic kernel parameters."""
103+
repl_dict = {}
104+
if "lam" in extra_kwargs:
105+
repl_dict[sp.Symbol("lam")] = extra_kwargs["lam"]
106+
if "k" in extra_kwargs:
107+
repl_dict[sp.Symbol("k")] = extra_kwargs["k"]
108+
return repl_dict
109+
110+
111+
@pytest.mark.parametrize("knl,extra_kwargs", [
112+
(LaplaceKernel(2), {}),
113+
(YukawaKernel(2), {"lam": 0.1}),
114+
(HelmholtzKernel(2), {"k": 0.5}),
115+
(BiharmonicKernel(2), {}),
116+
])
117+
def test_m2m_coeffs(
118+
actx_factory: ArrayContextFactory,
119+
knl: Kernel,
120+
extra_kwargs: Mapping[str, float],
121+
verbose: bool = True,
122+
):
123+
"""
124+
Compares two approaches:
125+
1. Compress coefficients -> embed -> M2M translate
126+
2. M2M translate with full coefficients
127+
128+
Verifies the difference formula between these two approaches.
129+
"""
130+
order = 7
131+
dim = 2
132+
repl_dict = get_repl_dict(knl, extra_kwargs)
133+
global_const = to_scalar(knl.get_global_scaling_const())
134+
135+
# Set up source, centers, and target
136+
source = np.array([[0.0], [0.1]])
137+
strength = np.array([1.0])
138+
139+
m_center1 = np.array([0.0, 0.0])
140+
offset_direction = np.array([-0.5, 0.25])
141+
c2_c1_dist = 0.1
142+
m_center2 = m_center1 + c2_c1_dist * offset_direction
143+
h = m_center2 - m_center1
144+
145+
target = np.array([[2.0], [2.0]])
146+
147+
if verbose:
148+
print(f"M2M Coefficient Verification for {type(knl).__name__}:")
149+
print(f"m_center1 = {m_center1}")
150+
print(f"m_center2 = {m_center2}")
151+
print(f"h = m_center2 - m_center1 = {h}")
152+
print()
153+
print(f"{'k':>3s} | {'ν(k)':>15s} | {'|ν(k)|':6s} | " # noqa: RUF001
154+
f"{'difference by formula':>31s} | "
155+
f"{'difference by direct computation':>31s} | "
156+
f"{'abs err':>10s}")
157+
print("-" * 120)
158+
159+
actx = actx_factory()
160+
161+
toy_ctx_full = t.ToyContext(
162+
knl,
163+
mpole_expn_class=VolumeTaylorMultipoleExpansion,
164+
extra_kernel_kwargs=extra_kwargs
165+
)
166+
167+
toy_ctx_local = t.ToyContext(
168+
knl,
169+
local_expn_class=LinearPDEConformingVolumeTaylorLocalExpansion,
170+
extra_kernel_kwargs=extra_kwargs
171+
)
172+
173+
p_full = t.PointSources(toy_ctx_full, source, weights=strength)
174+
p2m_full = t.multipole_expand(actx, p_full, m_center1, order=order, rscale=1.0)
175+
176+
p_local = t.PointSources(toy_ctx_local, m_center2.reshape(2, 1), weights=strength)
177+
p2l = t.local_expand(actx, p_local, target, order=order)
178+
179+
mexpn = LinearPDEConformingVolumeTaylorMultipoleExpansion(knl, order)
180+
181+
# Build matrix M
182+
wrangler = mexpn.expansion_terms_wrangler
183+
M_symbolic = wrangler.get_projection_matrix(rscale=1.0) # noqa: N806
184+
numeric_op = NumericMatVecOperator(M_symbolic, repl_dict)
185+
M = build_matrix(numeric_op, dtype=np.complex128) # noqa: N806
186+
coeffs_full = (M @ p2l.coeffs) * global_const
187+
188+
# Get coefficient identifiers
189+
stored_identifiers = mexpn.get_coefficient_identifiers()
190+
full_identifiers = mexpn.get_full_coefficient_identifiers()
191+
is_stored = [mi in stored_identifiers for mi in full_identifiers]
192+
stored_indices = [i for i, st in enumerate(is_stored) if st]
193+
194+
mexpn_full = VolumeTaylorMultipoleExpansion(knl, order)
195+
mexpn_full_idx = mexpn_full.get_full_coefficient_identifiers()
196+
197+
max_abs_error = 0.0
198+
199+
for k, nu_k in enumerate(full_identifiers):
200+
k_card = sum(np.array(nu_k))
201+
# assume all coefficient values are 1
202+
alpha_k = 1
203+
204+
true_k_idx = mexpn_full_idx.index(nu_k)
205+
basis_full = np.zeros(len(mexpn_full_idx), dtype=np.complex128)
206+
basis_full[true_k_idx] = alpha_k
207+
p2m_full_k = p2m_full.with_coeffs(basis_full)
208+
209+
# M^T @ alpha
210+
basis_cmp = np.zeros(M.shape[0], dtype=np.complex128)
211+
basis_cmp[stored_indices] = M[k, :] * alpha_k
212+
213+
# Embed back into full basis
214+
basis_cmp_full = np.zeros(len(mexpn_full_idx), dtype=np.complex128)
215+
for i, nu_i in enumerate(full_identifiers):
216+
if basis_cmp[i] != 0:
217+
true_i_idx = mexpn_full_idx.index(nu_i)
218+
basis_cmp_full[true_i_idx] = basis_cmp[i]
219+
220+
p2m_cmp_k = p2m_full.with_coeffs(basis_cmp_full)
221+
222+
p2m2m_cmp = t.multipole_expand(
223+
actx, p2m_cmp_k, m_center2, order=order
224+
).eval(actx, target)
225+
p2m2m_full = t.multipole_expand(
226+
actx, p2m_full_k, m_center2, order=order
227+
).eval(actx, target)
228+
229+
direct_diff = (p2m2m_cmp - p2m2m_full)[0]
230+
231+
error = 0.0 + 0.0j
232+
for s, nu_js in enumerate(stored_identifiers):
233+
nu_js_card = sum(np.array(nu_js))
234+
inner_sum = 0.0 + 0.0j
235+
236+
if nu_js_card <= k_card:
237+
start_idx = math.comb(order - k_card + dim, dim)
238+
end_idx = math.comb(order - nu_js_card + dim, dim)
239+
240+
for idx in range(start_idx, end_idx):
241+
nu_l = full_identifiers[idx]
242+
nu_sum = tuple(a + b for a, b in zip(nu_l, nu_js, strict=True))
243+
244+
if nu_sum not in full_identifiers:
245+
continue
246+
247+
derivative_idx = full_identifiers.index(nu_sum)
248+
h_pow = np.prod(h ** np.array(nu_l))
249+
fact_nu_l = np.prod(spsp.factorial(nu_l))
250+
251+
inner_sum += coeffs_full[derivative_idx] * h_pow / fact_nu_l
252+
253+
error += inner_sum * M[k, s]
254+
255+
abs_err = abs(error - direct_diff)
256+
max_abs_error = max(max_abs_error, abs_err)
257+
258+
if verbose:
259+
print(f"{k:3d} | {nu_k!s:>15s} | {k_card:6d} | "
260+
f"{error.real: .8e}{error.imag:+.8e}j | "
261+
f"{direct_diff.real: .8e}{direct_diff.imag:+.8e}j | "
262+
f"{abs_err:9.2e}")
263+
264+
if verbose:
265+
print(f"\nMaximum absolute error: {max_abs_error:.2e}")
266+
267+
assert max_abs_error < 1e-15, (
268+
f"{type(knl).__name__}: error {max_abs_error:.2e}"
269+
)
270+
271+
272+
if __name__ == "__main__":
273+
if len(sys.argv) > 1:
274+
exec(sys.argv[1])
275+
else:
276+
pytest.main([__file__])

0 commit comments

Comments
 (0)