|
76 | 76 | _HAVE_FUSION_ACTX = False |
77 | 77 |
|
78 | 78 |
|
79 | | -from arraycontext import ArrayContext, NumpyArrayContext |
| 79 | +from arraycontext import ArrayContext, EagerJAXArrayContext, NumpyArrayContext |
80 | 80 | from arraycontext.container import ArrayContainer |
81 | 81 | from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller |
82 | 82 | from arraycontext.pytest import ( |
| 83 | + _PytestEagerJaxArrayContextFactory, |
83 | 84 | _PytestNumpyArrayContextFactory, |
84 | 85 | _PytestPyOpenCLArrayContextFactoryWithClass, |
85 | 86 | _PytestPytatoPyOpenCLArrayContextFactory, |
@@ -428,6 +429,26 @@ def clone(self) -> Self: |
428 | 429 | # }}} |
429 | 430 |
|
430 | 431 |
|
| 432 | +# {{{ distributed + eager jax |
| 433 | + |
| 434 | +class MPIEagerJaxArrayContext(EagerJAXArrayContext, MPIBasedArrayContext): |
| 435 | + """An array context for using distributed computation with :mod:`jax` |
| 436 | + eager evaluation. |
| 437 | +
|
| 438 | + .. autofunction:: __init__ |
| 439 | + """ |
| 440 | + |
| 441 | + def __init__(self, mpi_communicator) -> None: |
| 442 | + super().__init__() |
| 443 | + |
| 444 | + self.mpi_communicator = mpi_communicator |
| 445 | + |
| 446 | + def clone(self) -> Self: |
| 447 | + return type(self)(self.mpi_communicator) |
| 448 | + |
| 449 | +# }}} |
| 450 | + |
| 451 | + |
431 | 452 | # {{{ distributed + pytato array context subclasses |
432 | 453 |
|
433 | 454 | class MPIBasePytatoPyOpenCLArrayContext( |
@@ -521,12 +542,23 @@ def __call__(self): |
521 | 542 | return self.actx_class() |
522 | 543 |
|
523 | 544 |
|
| 545 | +class PytestEagerJAXArrayContextFactory(_PytestEagerJaxArrayContextFactory): |
| 546 | + actx_class = EagerJAXArrayContext |
| 547 | + |
| 548 | + def __call__(self): |
| 549 | + import jax |
| 550 | + jax.config.update("jax_enable_x64", True) |
| 551 | + return self.actx_class() |
| 552 | + |
| 553 | + |
524 | 554 | register_pytest_array_context_factory("grudge.pyopencl", |
525 | 555 | PytestPyOpenCLArrayContextFactory) |
526 | 556 | register_pytest_array_context_factory("grudge.pytato-pyopencl", |
527 | 557 | PytestPytatoPyOpenCLArrayContextFactory) |
528 | 558 | register_pytest_array_context_factory("grudge.numpy", |
529 | 559 | PytestNumpyArrayContextFactory) |
| 560 | +register_pytest_array_context_factory("grudge.eager-jax", |
| 561 | + PytestEagerJAXArrayContextFactory) |
530 | 562 |
|
531 | 563 | # }}} |
532 | 564 |
|
|
0 commit comments