-
Notifications
You must be signed in to change notification settings - Fork 11
Initial stab at actx.trace_call #210
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from 1 commit
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
d2dde87
Initial stab at actx.trace_call
MTCam 198d2cf
Merge remote-tracking branch 'origin/main' into trace-call
MTCam 968b5d3
Merge with inducer/main
MTCam 5472972
add (commented) output parsing section
MTCam 61af2e1
discover content of output_template
MTCam f055987
Merge remote-tracking branch 'origin/main' into mrg-upstream
MTCam File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -520,6 +520,12 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: | |
| """ | ||
| return f | ||
|
|
||
| # Supporting interface for function/call tracing in actx implementations | ||
| def trace_call(self, f: Callable[..., Any], | ||
| *args, identifier=None, **kwargs): | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you also want the interface to apply tags to the call site. |
||
| """Returns the result of the called function *f* with the specified args.""" | ||
| return f(*args, **kwargs) | ||
|
|
||
| # undocumented for now | ||
| @property | ||
| @abstractmethod | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,29 +44,105 @@ | |
|
|
||
| import abc | ||
| import sys | ||
| from typing import (Any, Callable, Union, Tuple, Type, FrozenSet, Dict, Optional, | ||
| TYPE_CHECKING) | ||
|
|
||
| from typing import ( # noqa | ||
| Any, Callable, Dict, FrozenSet, Tuple, Type, Union, TypeVar, Optional, | ||
| Hashable, Sequence, ClassVar, Iterator, Iterable, Mapping, | ||
| TYPE_CHECKING | ||
| ) | ||
| import numpy as np | ||
| from pytools.tag import ToTagSetConvertible, normalize_tags, Tag | ||
|
|
||
| from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike | ||
| from arraycontext.container.traversal import (rec_map_array_container, | ||
| with_array_context) | ||
| from arraycontext.context import ( | ||
| ArrayT, ArrayContext, Array, ArrayOrContainer, ScalarLike | ||
| ) | ||
| from arraycontext.container.traversal import ( | ||
| rec_map_array_container, | ||
| with_array_context, | ||
| rec_keyed_map_array_container | ||
| ) | ||
|
|
||
| from arraycontext.container import ArrayContainer, is_array_container_type | ||
| from arraycontext.metadata import NameHint | ||
| from pytools import memoize_method | ||
| from dataclasses import dataclass | ||
| from pyrsistent import pmap, PMap | ||
| import pytato as pt | ||
| # from pt.array import _get_default_axes, _get_default_tags | ||
| # from pt.tags import FunctionIdentifier | ||
| import itertools | ||
|
|
||
| if TYPE_CHECKING: | ||
| import pytato | ||
| # import pytato | ||
| import pyopencl as cl | ||
|
|
||
| if getattr(sys, "_BUILDING_SPHINX_DOCS", False): | ||
| import pyopencl as cl # noqa: F811 | ||
|
|
||
|
|
||
| import re | ||
| import logging | ||
| logger = logging.getLogger(__name__) | ||
|
|
||
| ReturnT = TypeVar("ReturnT", Array, Tuple[Array, ...], Dict[str, Array], | ||
| ArrayContainer) | ||
| RE_ARGNAME = re.compile(r"^_pt_(\d+)$") | ||
|
|
||
|
|
||
| def _to_identifier(s: str) -> str: | ||
| return "".join(ch for ch in s if ch.isidentifier()) | ||
|
|
||
|
|
||
| def _prg_id_to_kernel_name(f: Any) -> str: | ||
| if callable(f): | ||
| name = getattr(f, "__name__", "<anonymous>") | ||
| if not name.isidentifier(): | ||
| return "actx_compiled_" + _to_identifier(name) | ||
| else: | ||
| return name | ||
| else: | ||
| return _to_identifier(str(f)) | ||
|
|
||
|
|
||
| class _Guess(): | ||
| pass | ||
|
|
||
|
|
||
| class FromArrayContextCompile(Tag): | ||
| """ | ||
| Tagged to the entrypoint kernel of every translation unit that is generated | ||
| by :meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`. | ||
|
|
||
| Typically this tag serves as a branch condition in implementing a | ||
| specialized transform strategy for kernels compiled by | ||
| :meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`. | ||
| """ | ||
|
|
||
|
|
||
| # {{{ helper classes: AbstractInputDescriptor | ||
|
|
||
| class AbstractInputDescriptor: | ||
| """ | ||
| Used internally in :class:`BaseLazilyCompilingFunctionCaller` to characterize | ||
| an input. | ||
| """ | ||
| def __eq__(self, other): | ||
| raise NotImplementedError | ||
|
|
||
| def __hash__(self): | ||
| raise NotImplementedError | ||
|
|
||
|
|
||
| @dataclass(frozen=True, eq=True) | ||
| class ScalarInputDescriptor(AbstractInputDescriptor): | ||
| dtype: np.dtype | ||
|
|
||
|
|
||
| @dataclass(frozen=True, eq=True) | ||
| class LeafArrayDescriptor(AbstractInputDescriptor): | ||
| dtype: np.dtype | ||
| shape: pt.array.ShapeType | ||
|
|
||
| # }}} | ||
|
|
||
|
|
||
| # {{{ tag conversion | ||
|
|
||
|
|
@@ -169,8 +245,8 @@ def empty_like(self, ary): | |
|
|
||
| # {{{ compilation | ||
|
|
||
| def transform_dag(self, dag: "pytato.DictOfNamedArrays" | ||
| ) -> "pytato.DictOfNamedArrays": | ||
| def transform_dag(self, dag: "pt.DictOfNamedArrays" | ||
| ) -> "pt.DictOfNamedArrays": | ||
| """ | ||
| Returns a transformed version of *dag*. Sub-classes are supposed to | ||
| override this method to implement context-specific transformations on | ||
|
|
@@ -609,10 +685,21 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: | |
| from .compile import LazilyPyOpenCLCompilingFunctionCaller | ||
| return LazilyPyOpenCLCompilingFunctionCaller(self, f) | ||
|
|
||
| def transform_dag(self, dag: "pytato.DictOfNamedArrays" | ||
| ) -> "pytato.DictOfNamedArrays": | ||
| def transform_dag(self, dag: "pt.DictOfNamedArrays" | ||
| ) -> "pt.DictOfNamedArrays": | ||
| import pytato as pt | ||
| dag = pt.transform.materialize_with_mpms(dag) | ||
| dag = pt.tag_all_calls_to_be_inlined(dag) | ||
|
|
||
| # concated_dag = \ | ||
| # pt.concatenate_calls( | ||
| # dag, (lambda x: pt.tags.FunctionIdentifier("wvflux_int") | ||
| # in x.call.function.tags)) | ||
|
|
||
| # concated_dag = \ | ||
| # pt.concatenate_calls( | ||
| # dag, (lambda x: True)) | ||
|
|
||
| return dag | ||
|
|
||
| def einsum(self, spec, *args, arg_names=None, tagged=()): | ||
|
|
@@ -657,6 +744,85 @@ def preprocess_arg(name, arg): | |
| for name, arg in zip(arg_names, args) | ||
| ]).tagged(_preprocess_array_tags(tagged)) | ||
|
|
||
| def trace_call(self, f: Callable[..., ReturnT], | ||
| *args: Array, | ||
| identifier: Optional[Hashable] = None, | ||
| **kwargs: Array) -> ReturnT: | ||
| """ | ||
| Returns the expressions returned after calling *f* with the arguments | ||
| *args* and keyword arguments *kwargs*. The subexpressions in the returned | ||
| expressions are outlined as a :class:`~pytato.tracing.FunctionDefinition`. | ||
|
|
||
| :arg identifier: A hashable object that acts as | ||
| :attr:`pytato.tags.FunctionIdentifier.identifier` for the | ||
| :class:`~pytato.tags.FunctionIdentifier` tagged to the outlined | ||
| :class:`~pytato.tracing.FunctionDefinition`. If ``None`` the function | ||
| definition is not tagged with a | ||
| :class:`~pytato.tags.FunctionIdentifier` tag, if ``_Guess`` the | ||
| function identifier is guessed from ``f.__name__``. | ||
| """ | ||
| if identifier is _Guess: | ||
| # partials might not have a __name__ attribute | ||
| identifier = getattr(f, "__name__", None) | ||
|
|
||
| for kw in kwargs: | ||
| if RE_ARGNAME.match(kw): | ||
| # avoid collision between argument names | ||
| raise ValueError(f"Kw argument named '{kw}' not allowed.") | ||
|
|
||
| arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( | ||
| args, kwargs) | ||
|
|
||
| # dict_of_named_arrays = {} | ||
| # output_id_to_name_in_program = {} | ||
|
|
||
| input_id_to_name_in_program = { | ||
| arg_id: f"_actx_in_{_ary_container_key_stringifier(arg_id)}" | ||
| for arg_id in arg_id_to_arg} | ||
|
|
||
| # Get placeholders from the ``args``, ``kwargs``. | ||
| pl_args = [_get_f_placeholder_args(arg, iarg, | ||
| input_id_to_name_in_program, actx=self) | ||
| for iarg, arg in enumerate(args)] | ||
|
|
||
| pl_kwargs = {kw: _get_f_placeholder_args(arg, kw, | ||
| input_id_to_name_in_program, | ||
| actx=self) | ||
| for kw, arg in kwargs.items()} | ||
|
|
||
| # Pass the placeholders | ||
| output_template = f(*pl_args, **pl_kwargs) | ||
| print(f"{output_template=}") | ||
|
|
||
| # construct the function | ||
| # function = FunctionDefinition( | ||
| # frozenset(pl_arg.name for pl_arg in pl_args) | frozenset(pl_kwargs), | ||
| # Map(returns), | ||
| # tags=_get_default_tags() | (frozenset([FunctionIdentifier(identifier)]) | ||
| # if identifier | ||
| # else frozenset()) | ||
| # ) | ||
| # traced_call = Call(function, | ||
| # (Map({pl.name: arg for pl, arg in zip(pl_args, args)}) | ||
| # .update(Map({pl_kwargs[kw].name: arg | ||
| # for kw, arg in kwargs.items()}))), | ||
| # result_tags=Map({name: _get_default_tags() | ||
| # for name in returns}), | ||
| # result_axes=Map({name: _get_default_axes(ret.ndim) | ||
| # for name, ret in returns.items()}), | ||
| # tags=_get_default_tags(), | ||
| # ) | ||
|
|
||
| # if isinstance(output, Array): | ||
| # return traced_call["_"] | ||
| # elif isinstance(output, tuple): | ||
| # return tuple(traced_call[f"_{iarg}"] for iarg in range(len(output))) | ||
| # elif isinstance(output, dict): | ||
| # return {kw: traced_call[kw] for kw in output} | ||
| #else: | ||
| # raise NotImplementedError(type(output)) | ||
| return f(*args, **kwargs) | ||
|
|
||
| def clone(self): | ||
| return type(self)(self.queue, self.allocator) | ||
|
|
||
|
|
@@ -896,4 +1062,140 @@ def clone(self): | |
|
|
||
| # }}} | ||
|
|
||
| # {{{ utilities | ||
|
|
||
|
|
||
| def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't these functions already exist in |
||
| """ | ||
| Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an | ||
| array-container's component's key. Goals of this routine: | ||
|
|
||
| * No two different keys should have the same stringification | ||
| * Stringified key must a valid identifier according to :meth:`str.isidentifier` | ||
| * (informal) Shorter identifiers are preferred | ||
| """ | ||
| def _rec_str(key: Any) -> str: | ||
| if isinstance(key, (str, int)): | ||
| return str(key) | ||
| elif isinstance(key, tuple): | ||
| # t in '_actx_t': stands for tuple | ||
| return "_actx_t" + "_".join(_rec_str(k) for k in key) + "_actx_endt" | ||
| else: | ||
| raise NotImplementedError("Key-stringication unimplemented for " | ||
| f"'{type(key).__name__}'.") | ||
|
|
||
| return "_".join(_rec_str(key) for key in keys) | ||
|
|
||
|
|
||
| def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], | ||
| kwargs: Mapping[str, Any] | ||
| ) -> "Tuple[PMap[Tuple[Any, ...],\ | ||
| Any],\ | ||
| PMap[Tuple[Any, ...],\ | ||
| AbstractInputDescriptor]\ | ||
| ]": | ||
| """ | ||
| Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Extracts | ||
| mappings from argument id to argument values and from argument id to | ||
| :class:`AbstractInputDescriptor`. See | ||
| :attr:`CompiledFunction.input_id_to_name_in_program` for argument-id's | ||
| representation. | ||
| """ | ||
| arg_id_to_arg: Dict[Tuple[Any, ...], Any] = {} | ||
| arg_id_to_descr: Dict[Tuple[Any, ...], AbstractInputDescriptor] = {} | ||
|
|
||
| for kw, arg in itertools.chain(enumerate(args), | ||
| kwargs.items()): | ||
| if np.isscalar(arg): | ||
| arg_id = (kw,) | ||
| arg_id_to_arg[arg_id] = arg | ||
| arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(type(arg))) | ||
| elif is_array_container_type(arg.__class__): | ||
| def id_collector(keys, ary): | ||
| arg_id = (kw,) + keys # noqa: B023 | ||
| arg_id_to_arg[arg_id] = ary # noqa: B023 | ||
| arg_id_to_descr[arg_id] = LeafArrayDescriptor( # noqa: B023 | ||
| np.dtype(ary.dtype), ary.shape) | ||
| return ary | ||
|
|
||
| rec_keyed_map_array_container(id_collector, arg) | ||
| elif isinstance(arg, pt.Array): | ||
| arg_id = (kw,) | ||
| arg_id_to_arg[arg_id] = arg | ||
| arg_id_to_descr[arg_id] = LeafArrayDescriptor(np.dtype(arg.dtype), | ||
| arg.shape) | ||
| else: | ||
| raise ValueError("Argument to a compiled operator should be" | ||
| " either a scalar, pt.Array or an array container. Got" | ||
| f" '{arg}'.") | ||
|
|
||
| return pmap(arg_id_to_arg), pmap(arg_id_to_descr) | ||
|
|
||
|
|
||
| def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext): | ||
| """ | ||
| Preprocess *ary* before turning it into a :class:`pytato.array.Placeholder` | ||
| in :meth:`LazilyCompilingFunctionCaller.__call__`. | ||
|
|
||
| Preprocessing here refers to: | ||
|
|
||
| - Metadata Inference that is supplied via *actx*\'s | ||
| :meth:`PytatoPyOpenCLArrayContext.transform_dag`. | ||
| """ | ||
| import pyopencl.array as cla | ||
| from arraycontext.impl.pyopencl.taggable_cl_array import (to_tagged_cl_array, | ||
| TaggableCLArray) | ||
| if isinstance(ary, pt.Array): | ||
| dag = pt.make_dict_of_named_arrays({"_actx_out": ary}) | ||
| # Transform the DAG to give metadata inference a chance to do its job | ||
| return actx.transform_dag(dag)["_actx_out"].expr | ||
| elif isinstance(ary, TaggableCLArray): | ||
| return ary | ||
| elif isinstance(ary, cla.Array): | ||
| from warnings import warn | ||
| warn("Passing pyopencl.array.Array to a compiled callable" | ||
| " is deprecated and will stop working in 2023." | ||
| " Use `to_tagged_cl_array` to convert the array to" | ||
| " TaggableCLArray", DeprecationWarning, stacklevel=2) | ||
|
|
||
| return to_tagged_cl_array(ary, | ||
| axes=None, | ||
| tags=frozenset()) | ||
| else: | ||
| raise NotImplementedError(type(ary)) | ||
|
|
||
|
|
||
| def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): | ||
| """ | ||
| Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. Returns the | ||
| placeholder version of an argument to | ||
| :attr:`BaseLazilyCompilingFunctionCaller.f`. | ||
| """ | ||
| if np.isscalar(arg): | ||
| name = arg_id_to_name[(kw,)] | ||
| return pt.make_placeholder(name, (), np.dtype(type(arg))) | ||
| elif isinstance(arg, pt.Array): | ||
| name = arg_id_to_name[(kw,)] | ||
| # Transform the DAG to give metadata inference a chance to do its job | ||
| arg = _to_input_for_compiled(arg, actx) | ||
| return pt.make_placeholder(name, arg.shape, arg.dtype, | ||
| axes=arg.axes, | ||
| tags=arg.tags) | ||
| elif is_array_container_type(arg.__class__): | ||
| def _rec_to_placeholder(keys, ary): | ||
| name = arg_id_to_name[(kw,) + keys] | ||
| # Transform the DAG to give metadata inference a chance to do its job | ||
| ary = _to_input_for_compiled(ary, actx) | ||
| return pt.make_placeholder(name, | ||
| ary.shape, | ||
| ary.dtype, | ||
| axes=ary.axes, | ||
| tags=ary.tags) | ||
|
|
||
| return rec_keyed_map_array_container(_rec_to_placeholder, arg) | ||
| else: | ||
| raise NotImplementedError(type(arg)) | ||
|
|
||
| # }}} | ||
|
|
||
| # vim: foldmethod=marker | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing return type annotation. More generally,
trace_callreally should permit the user to do two things:I'm not sure we'll be able to do both with just a single return value.