diff --git a/examples/pulse.py b/examples/pulse.py index 74b5bb6aa..8776e6cb3 100644 --- a/examples/pulse.py +++ b/examples/pulse.py @@ -340,6 +340,8 @@ def my_rhs(t, state): help="use entropy-stable dg for inviscid terms.") parser.add_argument("--numpy", action="store_true", help="use numpy-based eager actx.") + parser.add_argument("--jax", action="store_true", + help="use jax-based lazy actx.") parser.add_argument("--restart_file", help="root name of restart file") parser.add_argument("--casename", help="casename to use for i/o") args = parser.parse_args() @@ -354,7 +356,8 @@ def my_rhs(t, state): from mirgecom.array_context import get_reasonable_array_context_class actx_class = get_reasonable_array_context_class( - lazy=args.lazy, distributed=True, profiling=args.profiling, numpy=args.numpy) + lazy=args.lazy, distributed=True, profiling=args.profiling, + numpy=args.numpy, jax=args.jax) logging.basicConfig(format="%(message)s", level=logging.INFO) if args.casename: diff --git a/mirgecom/array_context.py b/mirgecom/array_context.py index 423244c2f..9fbabb9e0 100644 --- a/mirgecom/array_context.py +++ b/mirgecom/array_context.py @@ -45,11 +45,27 @@ def get_reasonable_array_context_class(*, lazy: bool, distributed: bool, - profiling: bool, numpy: bool = False) -> Type[ArrayContext]: + profiling: bool, numpy: bool = False, + jax: bool = False) -> Type[ArrayContext]: """Return a :class:`~arraycontext.ArrayContext` with the given constraints.""" if lazy and profiling: raise ValueError("Can't specify both lazy and profiling") + if jax: + if numpy: + raise ValueError("Can't specify both jax and numpy") + if profiling: + raise ValueError("Can't specify both jax and profiling") + if not lazy: + raise ValueError("jax needs lazy") + + if distributed: + from grudge.array_context import MPIPytatoJAXArrayContext + return MPIPytatoJAXArrayContext + else: + from grudge.array_context import PytatoJAXArrayContext + return PytatoJAXArrayContext + if numpy: if profiling: raise ValueError("Can't specify both numpy and profiling") @@ -100,6 +116,12 @@ def actx_class_is_numpy(actx_class: Type[ArrayContext]) -> bool: return issubclass(actx_class, NumpyArrayContext) +def actx_class_is_jax(actx_class: Type[ArrayContext]) -> bool: + """Return True if *actx_class* is jax-based.""" + from grudge.array_context import EagerJAXArrayContext + return issubclass(actx_class, EagerJAXArrayContext) + + def actx_class_is_distributed(actx_class: Type[ArrayContext]) -> bool: """Return True if *actx_class* is distributed.""" from grudge.array_context import MPIBasedArrayContext @@ -306,7 +328,13 @@ def initialize_actx( if comm: actx_kwargs["mpi_communicator"] = comm - if actx_class_is_numpy(actx_class): + if actx_class_is_jax(actx_class): + from grudge.array_context import MPIEagerJAXArrayContext + if comm: + assert issubclass(actx_class, MPIEagerJAXArrayContext) + else: + assert not issubclass(actx_class, MPIEagerJAXArrayContext) + elif actx_class_is_numpy(actx_class): from grudge.array_context import MPINumpyArrayContext if comm: assert issubclass(actx_class, MPINumpyArrayContext)