Skip to content

Commit 049295e

Browse files
kaushikcfdinducer
authored andcommitted
adds PymbolicToASTMapper
1 parent a91e865 commit 049295e

1 file changed

Lines changed: 250 additions & 1 deletion

File tree

pymbolic/interop/ast.py

Lines changed: 250 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
__copyright__ = "Copyright (C) 2015 Andreas Kloeckner"
1+
__copyright__ = """
2+
Copyright (C) 2015 Andreas Kloeckner
3+
Copyright (C) 2022 Kaushik Kulkarni
4+
"""
25

36
__license__ = """
47
Permission is hereby granted, free of charge, to any person obtaining a copy
@@ -22,6 +25,9 @@
2225

2326
import ast
2427
import pymbolic.primitives as p
28+
from typing import Tuple, List, Any
29+
from pymbolic.typing import ExpressionT, ScalarT
30+
from pymbolic.mapper import CachedMapper
2531

2632
__doc__ = r'''
2733
@@ -252,4 +258,247 @@ def map_Tuple(self, expr): # noqa
252258

253259
# }}}
254260

261+
262+
# {{{ PymbolicToASTMapper
263+
264+
class PymbolicToASTMapper(CachedMapper):
265+
def map_variable(self, expr) -> ast.expr:
266+
return ast.Name(id=expr.name)
267+
268+
def _map_multi_children_op(self,
269+
children: Tuple[ExpressionT, ...],
270+
op_type: ast.operator) -> ast.expr:
271+
rec_children = [self.rec(child) for child in children]
272+
result = rec_children[-1]
273+
for child in rec_children[-2::-1]:
274+
result = ast.BinOp(child, op_type, result)
275+
276+
return result
277+
278+
def map_sum(self, expr: p.Sum) -> ast.expr:
279+
return self._map_multi_children_op(expr.children, ast.Add())
280+
281+
def map_product(self, expr: p.Product) -> ast.expr:
282+
return self._map_multi_children_op(expr.children, ast.Mult())
283+
284+
def map_constant(self, expr: ScalarT) -> ast.expr:
285+
import sys
286+
if isinstance(expr, bool):
287+
return ast.NameConstant(expr)
288+
else:
289+
# needed because of https://bugs.python.org/issue36280
290+
if sys.version_info < (3, 8):
291+
return ast.Num(expr)
292+
else:
293+
return ast.Constant(expr, None)
294+
295+
def map_call(self, expr: p.Call) -> ast.expr:
296+
return ast.Call(
297+
func=self.rec(expr.function),
298+
args=[self.rec(param) for param in expr.parameters])
299+
300+
def map_call_with_kwargs(self, expr) -> ast.expr:
301+
return ast.Call(
302+
func=self.rec(expr.function),
303+
args=[self.rec(param) for param in expr.parameters],
304+
keywords=[
305+
ast.keyword(
306+
arg=kw,
307+
value=self.rec(param))
308+
for kw, param in sorted(expr.kw_parameters.items())])
309+
310+
def map_subscript(self, expr) -> ast.expr:
311+
return ast.Subscript(value=self.rec(expr.aggregate),
312+
slice=self.rec(expr.index))
313+
314+
def map_lookup(self, expr) -> ast.expr:
315+
return ast.Attribute(self.rec(expr.aggregate),
316+
expr.name)
317+
318+
def map_quotient(self, expr) -> ast.expr:
319+
return self._map_multi_children_op((expr.numerator,
320+
expr.denominator),
321+
ast.Div())
322+
323+
def map_floor_div(self, expr) -> ast.expr:
324+
return self._map_multi_children_op((expr.numerator,
325+
expr.denominator),
326+
ast.FloorDiv())
327+
328+
def map_remainder(self, expr) -> ast.expr:
329+
return self._map_multi_children_op((expr.numerator,
330+
expr.denominator),
331+
ast.Mod())
332+
333+
def map_power(self, expr) -> ast.expr:
334+
return self._map_multi_children_op((expr.base,
335+
expr.exponent),
336+
ast.Pow())
337+
338+
def map_left_shift(self, expr) -> ast.expr:
339+
return self._map_multi_children_op((expr.shiftee,
340+
expr.shift),
341+
ast.LShift())
342+
343+
def map_right_shift(self, expr) -> ast.expr:
344+
return self._map_multi_children_op((expr.numerator,
345+
expr.denominator),
346+
ast.RShift())
347+
348+
def map_bitwise_not(self, expr) -> ast.expr:
349+
return ast.UnaryOp(ast.Invert(), self.rec(expr.child))
350+
351+
def map_bitwise_or(self, expr) -> ast.expr:
352+
return self._map_multi_children_op(expr.children,
353+
ast.BitOr())
354+
355+
def map_bitwise_xor(self, expr) -> ast.expr:
356+
return self._map_multi_children_op(expr.children,
357+
ast.BitXor())
358+
359+
def map_bitwise_and(self, expr) -> ast.expr:
360+
return self._map_multi_children_op(expr.children,
361+
ast.BitAnd())
362+
363+
def map_logical_not(self, expr) -> ast.expr:
364+
return ast.UnaryOp(self.rec(expr.child), ast.Not())
365+
366+
def map_logical_or(self, expr) -> ast.expr:
367+
return ast.BoolOp(ast.Or(), [self.rec(child)
368+
for child in expr.children])
369+
370+
def map_logical_and(self, expr) -> ast.expr:
371+
return ast.BoolOp(ast.And(), [self.rec(child)
372+
for child in expr.children])
373+
374+
def map_list(self, expr: List[Any]) -> ast.expr:
375+
return ast.List([self.rec(el) for el in expr])
376+
377+
def map_tuple(self, expr: Tuple[Any, ...]) -> ast.expr:
378+
return ast.Tuple([self.rec(el) for el in expr])
379+
380+
def map_if(self, expr: p.If) -> ast.expr:
381+
return ast.IfExp(test=self.rec(expr.condition),
382+
body=self.rec(expr.then),
383+
orelse=self.rec(expr.else_))
384+
385+
def map_nan(self, expr: p.NaN) -> ast.expr:
386+
if isinstance(expr.data_type(float("nan")), float):
387+
return ast.Call(
388+
ast.Name(id="float"),
389+
args=[ast.Constant("nan")],
390+
keywords=[])
391+
else:
392+
# TODO: would need attributes of NumPy
393+
raise NotImplementedError("Non-float nan not implemented")
394+
395+
def map_slice(self, expr: p.Slice) -> ast.expr:
396+
return ast.Slice(*[self.rec(child)
397+
for child in expr.children])
398+
399+
def map_numpy_array(self, expr) -> ast.expr:
400+
raise NotImplementedError
401+
402+
def map_multivector(self, expr) -> ast.expr:
403+
raise NotImplementedError
404+
405+
def map_common_subexpression(self, expr) -> ast.expr:
406+
raise NotImplementedError
407+
408+
def map_substitution(self, expr) -> ast.expr:
409+
raise NotImplementedError
410+
411+
def map_derivative(self, expr) -> ast.expr:
412+
raise NotImplementedError
413+
414+
def map_if_positive(self, expr) -> ast.expr:
415+
raise NotImplementedError
416+
417+
def map_comparison(self, expr: p.Comparison) -> ast.expr:
418+
raise NotImplementedError
419+
420+
def map_polynomial(self, expr) -> ast.expr:
421+
raise NotImplementedError
422+
423+
def map_wildcard(self, expr) -> ast.expr:
424+
raise NotImplementedError
425+
426+
def map_dot_wildcard(self, expr) -> ast.expr:
427+
raise NotImplementedError
428+
429+
def map_star_wildcard(self, expr) -> ast.expr:
430+
raise NotImplementedError
431+
432+
def map_function_symbol(self, expr) -> ast.expr:
433+
raise NotImplementedError
434+
435+
def map_min(self, expr) -> ast.expr:
436+
raise NotImplementedError
437+
438+
def map_max(self, expr) -> ast.expr:
439+
raise NotImplementedError
440+
441+
442+
def to_python_ast(expr) -> ast.expr:
443+
"""
444+
Maps *expr* to :class:`ast.expr`.
445+
"""
446+
return PymbolicToASTMapper()(expr)
447+
448+
449+
def to_evaluatable_python_function(expr: ExpressionT,
450+
fn_name: str
451+
) -> str:
452+
"""
453+
Returns a :class:`str` of the python code with a single function *fn_name*
454+
that takes in the variables in *expr* as keyword-only arguments and returns
455+
the evaluated value of *expr*.
456+
457+
.. testsetup::
458+
459+
>>> from pymbolic import parse
460+
>>> from pymbolic.interop.ast import to_evaluatable_python_function
461+
462+
.. doctest::
463+
464+
>>> expr = parse("S//32 + E%32")
465+
>>> # Skipping doctest as astunparse and ast.unparse have certain subtle
466+
>>> # differences
467+
>>> print(to_evaluatable_python_function(expr, "foo"))) # doctest: +SKIP
468+
def foo(*, E, S):
469+
return S // 32 + E % 32
470+
"""
471+
import sys
472+
from pymbolic.mapper.dependency import CachedDependencyMapper
473+
474+
if sys.version_info < (3, 9):
475+
try:
476+
from astunparse import unparse
477+
except ImportError:
478+
raise RuntimeError("'to_evaluate_python_function' needs"
479+
"astunparse for Py<3.9. Install via `pip"
480+
" install astunparse`")
481+
else:
482+
unparse = ast.unparse
483+
484+
dep_mapper = CachedDependencyMapper(composite_leaves=True)
485+
deps = sorted({dep.name for dep in dep_mapper(expr)})
486+
487+
ast_func = ast.FunctionDef(name=fn_name,
488+
args=ast.arguments(args=[],
489+
posonlyargs=[],
490+
kwonlyargs=[ast.arg(dep, None)
491+
for dep in deps],
492+
kw_defaults=[None]*len(deps),
493+
vararg=None,
494+
kwarg=None,
495+
defaults=[]),
496+
body=[ast.Return(to_python_ast(expr))],
497+
decorator_list=[])
498+
ast_module = ast.Module([ast_func], type_ignores=[])
499+
500+
return unparse(ast.fix_missing_locations(ast_module))
501+
502+
# }}}
503+
255504
# vim: foldmethod=marker

0 commit comments

Comments
 (0)