Skip to content

Commit 21c6735

Browse files
alexfiklinducer
authored andcommitted
test: add test to pickle kernels
1 parent 9464f46 commit 21c6735

1 file changed

Lines changed: 44 additions & 1 deletion

File tree

sumpy/test/test_kernels.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import logging
2727
import sys
2828
from functools import partial
29+
from typing import TYPE_CHECKING
2930

3031
import numpy as np
3132
import numpy.linalg as la
@@ -36,7 +37,7 @@
3637
PyOpenCLArrayContext,
3738
pytest_generate_tests_for_array_contexts,
3839
)
39-
from pytools import obj_array
40+
from pytools import memoize_on_first_arg, obj_array
4041
from pytools.convergence import PConvergenceVerifier
4142

4243
import sumpy.symbolic as sym
@@ -62,15 +63,22 @@
6263
AxisTargetDerivative,
6364
BiharmonicKernel,
6465
DirectionalSourceDerivative,
66+
ElasticityKernel,
6567
HelmholtzKernel,
6668
Kernel,
6769
LaplaceKernel,
70+
LineOfCompressionKernel,
71+
OneKernel,
6872
StokesletKernel,
73+
StressletKernel,
6974
YukawaKernel,
7075
)
7176
from sumpy.test.geometries import make_ellipsoid, make_torus
7277

7378

79+
if TYPE_CHECKING:
80+
from collections.abc import Callable
81+
7482
logger = logging.getLogger(__name__)
7583

7684
pytest_generate_tests = pytest_generate_tests_for_array_contexts([
@@ -857,6 +865,8 @@ def test_m2m_compressed_error_helmholtz(actx_factory: ArrayContextFactory, dim,
857865
# }}}
858866

859867

868+
# {{{ test_jump
869+
860870
@pytest.mark.parametrize(("kernel_cls", "kernel_kwargs"), [
861871
(LaplaceKernel, {}),
862872
(HelmholtzKernel, {"k": 1}),
@@ -922,6 +932,39 @@ def test_jump(
922932
err = abs((inside-outside) - -1)
923933
assert err < tol, err
924934

935+
# }}}
936+
937+
938+
# {{{ test_pickle
939+
940+
@memoize_on_first_arg
941+
def get_kernel_name_for_test(knl: Kernel) -> Callable[[str], str]:
942+
return lambda prefix: f"{prefix}: {type(knl).__name__}"
943+
944+
945+
@pytest.mark.parametrize("knl", [
946+
BiharmonicKernel(2),
947+
ElasticityKernel(2, 0, 0),
948+
HelmholtzKernel(3, helmholtz_k_name="kay"),
949+
LaplaceKernel(3),
950+
LineOfCompressionKernel(),
951+
OneKernel(2),
952+
StokesletKernel(2, 0, 0),
953+
StressletKernel(2, 0, 0, 0),
954+
YukawaKernel(2, yukawa_lambda_name="lambda"),
955+
])
956+
def test_pickle(knl: Kernel) -> None:
957+
import pickle
958+
959+
result = pickle.dumps(knl)
960+
assert pickle.loads(result) == knl
961+
962+
_ = get_kernel_name_for_test(knl)
963+
result = pickle.dumps(knl)
964+
assert pickle.loads(result) == knl
965+
966+
# }}}
967+
925968

926969
# You can test individual routines by typing
927970
# $ python test_kernels.py 'test_p2p(_acf, True)'

0 commit comments

Comments
 (0)