Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion examples/pulse.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,8 @@ def my_rhs(t, state):
help="use entropy-stable dg for inviscid terms.")
parser.add_argument("--numpy", action="store_true",
help="use numpy-based eager actx.")
parser.add_argument("--jax", action="store_true",
help="use jax-based lazy actx.")
parser.add_argument("--restart_file", help="root name of restart file")
parser.add_argument("--casename", help="casename to use for i/o")
args = parser.parse_args()
Expand All @@ -354,7 +356,8 @@ def my_rhs(t, state):

from mirgecom.array_context import get_reasonable_array_context_class
actx_class = get_reasonable_array_context_class(
lazy=args.lazy, distributed=True, profiling=args.profiling, numpy=args.numpy)
lazy=args.lazy, distributed=True, profiling=args.profiling,
numpy=args.numpy, jax=args.jax)

logging.basicConfig(format="%(message)s", level=logging.INFO)
if args.casename:
Expand Down
32 changes: 30 additions & 2 deletions mirgecom/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,27 @@


def get_reasonable_array_context_class(*, lazy: bool, distributed: bool,
profiling: bool, numpy: bool = False) -> Type[ArrayContext]:
profiling: bool, numpy: bool = False,
jax: bool = False) -> Type[ArrayContext]:
"""Return a :class:`~arraycontext.ArrayContext` with the given constraints."""
if lazy and profiling:
raise ValueError("Can't specify both lazy and profiling")

if jax:
if numpy:
raise ValueError("Can't specify both jax and numpy")
if profiling:
raise ValueError("Can't specify both jax and profiling")
if not lazy:
raise ValueError("jax needs lazy")

if distributed:
from grudge.array_context import MPIPytatoJAXArrayContext
return MPIPytatoJAXArrayContext
else:
from grudge.array_context import PytatoJAXArrayContext
return PytatoJAXArrayContext

if numpy:
if profiling:
raise ValueError("Can't specify both numpy and profiling")
Expand Down Expand Up @@ -100,6 +116,12 @@ def actx_class_is_numpy(actx_class: Type[ArrayContext]) -> bool:
return issubclass(actx_class, NumpyArrayContext)


def actx_class_is_jax(actx_class: Type[ArrayContext]) -> bool:
"""Return True if *actx_class* is jax-based."""
from grudge.array_context import EagerJAXArrayContext
return issubclass(actx_class, EagerJAXArrayContext)


def actx_class_is_distributed(actx_class: Type[ArrayContext]) -> bool:
"""Return True if *actx_class* is distributed."""
from grudge.array_context import MPIBasedArrayContext
Expand Down Expand Up @@ -306,7 +328,13 @@ def initialize_actx(
if comm:
actx_kwargs["mpi_communicator"] = comm

if actx_class_is_numpy(actx_class):
if actx_class_is_jax(actx_class):
from grudge.array_context import MPIEagerJAXArrayContext
if comm:
assert issubclass(actx_class, MPIEagerJAXArrayContext)
else:
assert not issubclass(actx_class, MPIEagerJAXArrayContext)
elif actx_class_is_numpy(actx_class):
from grudge.array_context import MPINumpyArrayContext
if comm:
assert issubclass(actx_class, MPINumpyArrayContext)
Expand Down
Loading