Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,12 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
"""
return f

# Supporting interface for function/call tracing in actx implementations
def trace_call(self, f: Callable[..., Any],
*args, identifier=None, **kwargs):

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing return type annotation. More generally, trace_call really should permit the user to do two things:

  • Use the result of the traced call.
  • Call the trace result using new data.

I'm not sure we'll be able to do both with just a single return value.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you also want the interface to apply tags to the call site.

"""Returns the result of the called function *f* with the specified args."""
return f(*args, **kwargs)

# undocumented for now
@property
@abstractmethod
Expand Down
326 changes: 314 additions & 12 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,29 +44,105 @@

import abc
import sys
from typing import (Any, Callable, Union, Tuple, Type, FrozenSet, Dict, Optional,
TYPE_CHECKING)

from typing import ( # noqa
Any, Callable, Dict, FrozenSet, Tuple, Type, Union, TypeVar, Optional,
Hashable, Sequence, ClassVar, Iterator, Iterable, Mapping,
TYPE_CHECKING
)
import numpy as np
from pytools.tag import ToTagSetConvertible, normalize_tags, Tag

from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike
from arraycontext.container.traversal import (rec_map_array_container,
with_array_context)
from arraycontext.context import (
ArrayT, ArrayContext, Array, ArrayOrContainer, ScalarLike
)
from arraycontext.container.traversal import (
rec_map_array_container,
with_array_context,
rec_keyed_map_array_container
)

from arraycontext.container import ArrayContainer, is_array_container_type
from arraycontext.metadata import NameHint
from pytools import memoize_method
from dataclasses import dataclass
from pyrsistent import pmap, PMap
import pytato as pt
# from pt.array import _get_default_axes, _get_default_tags
# from pt.tags import FunctionIdentifier
import itertools

if TYPE_CHECKING:
import pytato
# import pytato
import pyopencl as cl

if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
import pyopencl as cl # noqa: F811


import re
import logging
logger = logging.getLogger(__name__)

ReturnT = TypeVar("ReturnT", Array, Tuple[Array, ...], Dict[str, Array],
ArrayContainer)
RE_ARGNAME = re.compile(r"^_pt_(\d+)$")


def _to_identifier(s: str) -> str:
return "".join(ch for ch in s if ch.isidentifier())


def _prg_id_to_kernel_name(f: Any) -> str:
if callable(f):
name = getattr(f, "__name__", "<anonymous>")
if not name.isidentifier():
return "actx_compiled_" + _to_identifier(name)
else:
return name
else:
return _to_identifier(str(f))


class _Guess():
pass


class FromArrayContextCompile(Tag):
"""
Tagged to the entrypoint kernel of every translation unit that is generated
by :meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`.

Typically this tag serves as a branch condition in implementing a
specialized transform strategy for kernels compiled by
:meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`.
"""


# {{{ helper classes: AbstractInputDescriptor

class AbstractInputDescriptor:
"""
Used internally in :class:`BaseLazilyCompilingFunctionCaller` to characterize
an input.
"""
def __eq__(self, other):
raise NotImplementedError

def __hash__(self):
raise NotImplementedError


@dataclass(frozen=True, eq=True)
class ScalarInputDescriptor(AbstractInputDescriptor):
dtype: np.dtype


@dataclass(frozen=True, eq=True)
class LeafArrayDescriptor(AbstractInputDescriptor):
dtype: np.dtype
shape: pt.array.ShapeType

# }}}


# {{{ tag conversion

Expand Down Expand Up @@ -169,8 +245,8 @@ def empty_like(self, ary):

# {{{ compilation

def transform_dag(self, dag: "pytato.DictOfNamedArrays"
) -> "pytato.DictOfNamedArrays":
def transform_dag(self, dag: "pt.DictOfNamedArrays"
) -> "pt.DictOfNamedArrays":
"""
Returns a transformed version of *dag*. Sub-classes are supposed to
override this method to implement context-specific transformations on
Expand Down Expand Up @@ -609,10 +685,21 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
from .compile import LazilyPyOpenCLCompilingFunctionCaller
return LazilyPyOpenCLCompilingFunctionCaller(self, f)

def transform_dag(self, dag: "pytato.DictOfNamedArrays"
) -> "pytato.DictOfNamedArrays":
def transform_dag(self, dag: "pt.DictOfNamedArrays"
) -> "pt.DictOfNamedArrays":
import pytato as pt
dag = pt.transform.materialize_with_mpms(dag)
dag = pt.tag_all_calls_to_be_inlined(dag)

# concated_dag = \
# pt.concatenate_calls(
# dag, (lambda x: pt.tags.FunctionIdentifier("wvflux_int")
# in x.call.function.tags))

# concated_dag = \
# pt.concatenate_calls(
# dag, (lambda x: True))

return dag

def einsum(self, spec, *args, arg_names=None, tagged=()):
Expand Down Expand Up @@ -657,6 +744,85 @@ def preprocess_arg(name, arg):
for name, arg in zip(arg_names, args)
]).tagged(_preprocess_array_tags(tagged))

def trace_call(self, f: Callable[..., ReturnT],
*args: Array,
identifier: Optional[Hashable] = None,
**kwargs: Array) -> ReturnT:
"""
Returns the expressions returned after calling *f* with the arguments
*args* and keyword arguments *kwargs*. The subexpressions in the returned
expressions are outlined as a :class:`~pytato.tracing.FunctionDefinition`.

:arg identifier: A hashable object that acts as
:attr:`pytato.tags.FunctionIdentifier.identifier` for the
:class:`~pytato.tags.FunctionIdentifier` tagged to the outlined
:class:`~pytato.tracing.FunctionDefinition`. If ``None`` the function
definition is not tagged with a
:class:`~pytato.tags.FunctionIdentifier` tag, if ``_Guess`` the
function identifier is guessed from ``f.__name__``.
"""
if identifier is _Guess:
# partials might not have a __name__ attribute
identifier = getattr(f, "__name__", None)

for kw in kwargs:
if RE_ARGNAME.match(kw):
# avoid collision between argument names
raise ValueError(f"Kw argument named '{kw}' not allowed.")

arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(
args, kwargs)

# dict_of_named_arrays = {}
# output_id_to_name_in_program = {}

input_id_to_name_in_program = {
arg_id: f"_actx_in_{_ary_container_key_stringifier(arg_id)}"
for arg_id in arg_id_to_arg}

# Get placeholders from the ``args``, ``kwargs``.
pl_args = [_get_f_placeholder_args(arg, iarg,
input_id_to_name_in_program, actx=self)
for iarg, arg in enumerate(args)]

pl_kwargs = {kw: _get_f_placeholder_args(arg, kw,
input_id_to_name_in_program,
actx=self)
for kw, arg in kwargs.items()}

# Pass the placeholders
output_template = f(*pl_args, **pl_kwargs)
print(f"{output_template=}")

# construct the function
# function = FunctionDefinition(
# frozenset(pl_arg.name for pl_arg in pl_args) | frozenset(pl_kwargs),
# Map(returns),
# tags=_get_default_tags() | (frozenset([FunctionIdentifier(identifier)])
# if identifier
# else frozenset())
# )
# traced_call = Call(function,
# (Map({pl.name: arg for pl, arg in zip(pl_args, args)})
# .update(Map({pl_kwargs[kw].name: arg
# for kw, arg in kwargs.items()}))),
# result_tags=Map({name: _get_default_tags()
# for name in returns}),
# result_axes=Map({name: _get_default_axes(ret.ndim)
# for name, ret in returns.items()}),
# tags=_get_default_tags(),
# )

# if isinstance(output, Array):
# return traced_call["_"]
# elif isinstance(output, tuple):
# return tuple(traced_call[f"_{iarg}"] for iarg in range(len(output)))
# elif isinstance(output, dict):
# return {kw: traced_call[kw] for kw in output}
#else:
# raise NotImplementedError(type(output))
return f(*args, **kwargs)

def clone(self):
return type(self)(self.queue, self.allocator)

Expand Down Expand Up @@ -896,4 +1062,140 @@ def clone(self):

# }}}

# {{{ utilities


def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't these functions already exist in impl.pytato.compile? What has changed?

"""
Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an
array-container's component's key. Goals of this routine:

* No two different keys should have the same stringification
* Stringified key must a valid identifier according to :meth:`str.isidentifier`
* (informal) Shorter identifiers are preferred
"""
def _rec_str(key: Any) -> str:
if isinstance(key, (str, int)):
return str(key)
elif isinstance(key, tuple):
# t in '_actx_t': stands for tuple
return "_actx_t" + "_".join(_rec_str(k) for k in key) + "_actx_endt"
else:
raise NotImplementedError("Key-stringication unimplemented for "
f"'{type(key).__name__}'.")

return "_".join(_rec_str(key) for key in keys)


def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...],
kwargs: Mapping[str, Any]
) -> "Tuple[PMap[Tuple[Any, ...],\
Any],\
PMap[Tuple[Any, ...],\
AbstractInputDescriptor]\
]":
"""
Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Extracts
mappings from argument id to argument values and from argument id to
:class:`AbstractInputDescriptor`. See
:attr:`CompiledFunction.input_id_to_name_in_program` for argument-id's
representation.
"""
arg_id_to_arg: Dict[Tuple[Any, ...], Any] = {}
arg_id_to_descr: Dict[Tuple[Any, ...], AbstractInputDescriptor] = {}

for kw, arg in itertools.chain(enumerate(args),
kwargs.items()):
if np.isscalar(arg):
arg_id = (kw,)
arg_id_to_arg[arg_id] = arg
arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(type(arg)))
elif is_array_container_type(arg.__class__):
def id_collector(keys, ary):
arg_id = (kw,) + keys # noqa: B023
arg_id_to_arg[arg_id] = ary # noqa: B023
arg_id_to_descr[arg_id] = LeafArrayDescriptor( # noqa: B023
np.dtype(ary.dtype), ary.shape)
return ary

rec_keyed_map_array_container(id_collector, arg)
elif isinstance(arg, pt.Array):
arg_id = (kw,)
arg_id_to_arg[arg_id] = arg
arg_id_to_descr[arg_id] = LeafArrayDescriptor(np.dtype(arg.dtype),
arg.shape)
else:
raise ValueError("Argument to a compiled operator should be"
" either a scalar, pt.Array or an array container. Got"
f" '{arg}'.")

return pmap(arg_id_to_arg), pmap(arg_id_to_descr)


def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext):
"""
Preprocess *ary* before turning it into a :class:`pytato.array.Placeholder`
in :meth:`LazilyCompilingFunctionCaller.__call__`.

Preprocessing here refers to:

- Metadata Inference that is supplied via *actx*\'s
:meth:`PytatoPyOpenCLArrayContext.transform_dag`.
"""
import pyopencl.array as cla
from arraycontext.impl.pyopencl.taggable_cl_array import (to_tagged_cl_array,
TaggableCLArray)
if isinstance(ary, pt.Array):
dag = pt.make_dict_of_named_arrays({"_actx_out": ary})
# Transform the DAG to give metadata inference a chance to do its job
return actx.transform_dag(dag)["_actx_out"].expr
elif isinstance(ary, TaggableCLArray):
return ary
elif isinstance(ary, cla.Array):
from warnings import warn
warn("Passing pyopencl.array.Array to a compiled callable"
" is deprecated and will stop working in 2023."
" Use `to_tagged_cl_array` to convert the array to"
" TaggableCLArray", DeprecationWarning, stacklevel=2)

return to_tagged_cl_array(ary,
axes=None,
tags=frozenset())
else:
raise NotImplementedError(type(ary))


def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx):
"""
Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. Returns the
placeholder version of an argument to
:attr:`BaseLazilyCompilingFunctionCaller.f`.
"""
if np.isscalar(arg):
name = arg_id_to_name[(kw,)]
return pt.make_placeholder(name, (), np.dtype(type(arg)))
elif isinstance(arg, pt.Array):
name = arg_id_to_name[(kw,)]
# Transform the DAG to give metadata inference a chance to do its job
arg = _to_input_for_compiled(arg, actx)
return pt.make_placeholder(name, arg.shape, arg.dtype,
axes=arg.axes,
tags=arg.tags)
elif is_array_container_type(arg.__class__):
def _rec_to_placeholder(keys, ary):
name = arg_id_to_name[(kw,) + keys]
# Transform the DAG to give metadata inference a chance to do its job
ary = _to_input_for_compiled(ary, actx)
return pt.make_placeholder(name,
ary.shape,
ary.dtype,
axes=ary.axes,
tags=ary.tags)

return rec_keyed_map_array_container(_rec_to_placeholder, arg)
else:
raise NotImplementedError(type(arg))

# }}}

# vim: foldmethod=marker