Skip to content

Commit f121329

Browse files
committed
add transform_dag implementation to pytato JAX array context in order to inline functions
1 parent 3e6886d commit f121329

1 file changed

Lines changed: 11 additions & 0 deletions

File tree

arraycontext/impl/pytato/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,9 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
821821
An arraycontext that uses :mod:`pytato` to represent the thawed state of
822822
the arrays and compiles the expressions using
823823
:class:`pytato.target.python.JAXPythonTarget`.
824+
825+
826+
.. automethod:: transform_dag
824827
"""
825828

826829
def __init__(self,
@@ -972,6 +975,14 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
972975
from .compile import LazilyJAXCompilingFunctionCaller
973976
return LazilyJAXCompilingFunctionCaller(self, f)
974977

978+
def transform_dag(self, dag: pytato.DictOfNamedArrays
979+
) -> pytato.DictOfNamedArrays:
980+
import pytato as pt
981+
982+
dag = pt.tag_all_calls_to_be_inlined(dag)
983+
dag = pt.inline_calls(dag)
984+
return dag
985+
975986
def tag(self, tags: ToTagSetConvertible, array):
976987
def _tag(ary):
977988
import jax.numpy as jnp

0 commit comments

Comments
 (0)