Skip to content

Commit 60fbd73

Browse files
add MPIEagerJaxArrayContext
1 parent 193100a commit 60fbd73

3 files changed

Lines changed: 39 additions & 3 deletions

File tree

grudge/array_context.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,11 @@
7676
_HAVE_FUSION_ACTX = False
7777

7878

79-
from arraycontext import ArrayContext, NumpyArrayContext
79+
from arraycontext import ArrayContext, EagerJAXArrayContext, NumpyArrayContext
8080
from arraycontext.container import ArrayContainer
8181
from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller
8282
from arraycontext.pytest import (
83+
_PytestEagerJaxArrayContextFactory,
8384
_PytestNumpyArrayContextFactory,
8485
_PytestPyOpenCLArrayContextFactoryWithClass,
8586
_PytestPytatoPyOpenCLArrayContextFactory,
@@ -428,6 +429,26 @@ def clone(self) -> Self:
428429
# }}}
429430

430431

432+
# {{{ distributed + eager jax
433+
434+
class MPIEagerJaxArrayContext(EagerJAXArrayContext, MPIBasedArrayContext):
435+
"""An array context for using distributed computation with :mod:`jax`
436+
eager evaluation.
437+
438+
.. autofunction:: __init__
439+
"""
440+
441+
def __init__(self, mpi_communicator) -> None:
442+
super().__init__()
443+
444+
self.mpi_communicator = mpi_communicator
445+
446+
def clone(self) -> Self:
447+
return type(self)(self.mpi_communicator)
448+
449+
# }}}
450+
451+
431452
# {{{ distributed + pytato array context subclasses
432453

433454
class MPIBasePytatoPyOpenCLArrayContext(
@@ -521,12 +542,23 @@ def __call__(self):
521542
return self.actx_class()
522543

523544

545+
class PytestEagerJAXArrayContextFactory(_PytestEagerJaxArrayContextFactory):
546+
actx_class = EagerJAXArrayContext
547+
548+
def __call__(self):
549+
import jax
550+
jax.config.update("jax_enable_x64", True)
551+
return self.actx_class()
552+
553+
524554
register_pytest_array_context_factory("grudge.pyopencl",
525555
PytestPyOpenCLArrayContextFactory)
526556
register_pytest_array_context_factory("grudge.pytato-pyopencl",
527557
PytestPytatoPyOpenCLArrayContextFactory)
528558
register_pytest_array_context_factory("grudge.numpy",
529559
PytestNumpyArrayContextFactory)
560+
register_pytest_array_context_factory("grudge.eager-jax",
561+
PytestEagerJAXArrayContextFactory)
530562

531563
# }}}
532564

test/test_dt_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from arraycontext import pytest_generate_tests_for_array_contexts
2828

2929
from grudge.array_context import (
30+
PytestEagerJAXArrayContextFactory,
3031
PytestNumpyArrayContextFactory,
3132
PytestPyOpenCLArrayContextFactory,
3233
PytestPytatoPyOpenCLArrayContextFactory,
@@ -36,7 +37,8 @@
3637
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
3738
[PytestPyOpenCLArrayContextFactory,
3839
PytestPytatoPyOpenCLArrayContextFactory,
39-
PytestNumpyArrayContextFactory])
40+
PytestNumpyArrayContextFactory,
41+
PytestEagerJAXArrayContextFactory])
4042

4143
import logging
4244

test/test_metrics.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from meshmode.dof_array import flat_norm
3434

3535
from grudge.array_context import (
36+
PytestEagerJAXArrayContextFactory,
3637
PytestNumpyArrayContextFactory,
3738
PytestPyOpenCLArrayContextFactory,
3839
PytestPytatoPyOpenCLArrayContextFactory,
@@ -44,7 +45,8 @@
4445
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
4546
[PytestPyOpenCLArrayContextFactory,
4647
PytestPytatoPyOpenCLArrayContextFactory,
47-
PytestNumpyArrayContextFactory])
48+
PytestNumpyArrayContextFactory,
49+
PytestEagerJAXArrayContextFactory])
4850

4951

5052
# {{{ inverse metric

0 commit comments

Comments
 (0)