From 381e682c08b540ffdbd3de72e89d16001d2d4c02 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 18 Oct 2022 16:56:02 +0200
Subject: [PATCH 1/2] Implement multi-output Elemwise in Numba via guvectorize

---
Reviewer notes (this section is not part of the commit message):
* numba_funcify_Elemwise now tests `len(node.outputs) == 1`. The previous
  `node.outputs == 1` compared the outputs collection itself with the int 1,
  which is not a length check, so the create_vectorize_func branch was not
  selected for single-output nodes.
* Still work-in-progress/debug code, to be removed before merging: the two
  print() calls in create_guvectorize_func, print("") in
  test_multioutput_elemwise, the early `return fn` in jit_compile (the
  numba.njit call below it is unreachable), and the commented-out
  eval_python_only call in tests/link/numba/test_basic.py.
* The `index` blob ids below were not regenerated after the edit above.

 aesara/link/numba/dispatch/elemwise.py | 52 +++++++++++++++++++++++++-
 aesara/link/numba/linker.py            |  1 +
 tests/link/numba/test_basic.py         |  2 +-
 tests/link/numba/test_elemwise.py      | 17 +++++++++
 4 files changed, 70 insertions(+), 2 deletions(-)

diff --git a/aesara/link/numba/dispatch/elemwise.py b/aesara/link/numba/dispatch/elemwise.py
index 6d2b360913..3cf70d847d 100644
--- a/aesara/link/numba/dispatch/elemwise.py
+++ b/aesara/link/numba/dispatch/elemwise.py
@@ -162,6 +162,53 @@ def create_vectorize_func(
     return elemwise_fn
 
 
+def create_guvectorize_func(
+    scalar_op_fn: Callable,
+    node: Apply,
+    identity: Optional[Any] = None,
+    **kwargs,
+) -> Callable:
+    r"""Create a guvectorized Numba function from an `Apply`'s Python function."""
+
+    signature_ = create_numba_signature(node, force_scalar=False)
+    signature = [(*signature_.args, *signature_.return_type.types)]
+
+    target = (
+        getattr(node.tag, "numba__vectorize_target", None)
+        or config.numba__vectorize_target
+    )
+
+    layout = f"{','.join(('()',) * len(node.inputs))}->{','.join(('()',) * len(node.outputs))}"
+    print(f"{signature=}, {layout=}")
+    numba_guvectorized_fn = numba.guvectorize(
+        signature,
+        layout,
+        identity=identity,
+        target=target,
+        fastmath=config.numba__fastmath,
+    )
+
+    input_names = [f"i{i}" for i in range(len(node.inputs))]
+    output_names = [f"o{i}" for i in range(len(node.outputs))]
+    gu_fn_name = "gu_func"
+
+    gu_fn_src = f"""
+def {gu_fn_name}({', '.join(input_names)}, {', '.join(output_names)}):
+    for i in range({input_names[0]}.shape[0]):
+        {'[i], '.join(output_names)}[i] = scalar_op_fn({'[i], '.join(input_names)}[i])
+"""
+    print(gu_fn_src)
+
+    gu_fn_inner = compile_function_src(
+        gu_fn_src, gu_fn_name, {"scalar_op_fn": scalar_op_fn, **globals()}
+    )
+
+    gu_fn = numba_guvectorized_fn(gu_fn_inner)
+    # gu_fn.py_scalar_func = py_scalar_func
+
+    return gu_fn
+
+
 def create_axis_reducer(
     scalar_op: Op,
     identity: Union[np.ndarray, Number],
@@ -426,7 +473,10 @@ def axis_apply_fn(x):
 def numba_funcify_Elemwise(op, node, **kwargs):
 
     scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)
-    elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
+    if len(node.outputs) == 1:
+        elemwise_fn = create_vectorize_func(scalar_op_fn, node, use_signature=False)
+    else:
+        elemwise_fn = create_guvectorize_func(scalar_op_fn, node)
     elemwise_fn_name = elemwise_fn.__name__
 
     if op.inplace_pattern:
diff --git a/aesara/link/numba/linker.py b/aesara/link/numba/linker.py
index bb390b0523..0a1b40dd58 100644
--- a/aesara/link/numba/linker.py
+++ b/aesara/link/numba/linker.py
@@ -27,6 +27,7 @@ def fgraph_convert(self, fgraph, **kwargs):
         return numba_funcify(fgraph, **kwargs)
 
     def jit_compile(self, fn):
+        return fn
         import numba
 
         jitted_fn = numba.njit(fn)
diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py
index 96ef53203d..a948d856a3 100644
--- a/tests/link/numba/test_basic.py
+++ b/tests/link/numba/test_basic.py
@@ -233,7 +233,7 @@ def assert_fn(x, y):
     numba_res = aesara_numba_fn(*inputs)
 
     # Get some coverage
-    eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
+    # eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
 
     if len(fn_outputs) > 1:
         for j, p in zip(numba_res, py_res):
diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py
index c7b30c8e6a..f6c92bdc06 100644
--- a/tests/link/numba/test_elemwise.py
+++ b/tests/link/numba/test_elemwise.py
@@ -3,6 +3,7 @@
 import numpy as np
 import pytest
 
+import aesara.scalar as aes
 import aesara.tensor as at
 import aesara.tensor.inplace as ati
 import aesara.tensor.math as aem
@@ -13,6 +14,7 @@
 from aesara.graph.basic import Constant
 from aesara.graph.fg import FunctionGraph
 from aesara.tensor import elemwise as at_elemwise
+from aesara.tensor.elemwise import Elemwise
 from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
 from tests.link.numba.test_basic import (
     compare_numba_and_py,
@@ -111,6 +113,21 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
         compare_numba_and_py(out_fg, input_vals)
 
 
+def test_multioutput_elemwise():
+    scalar_inp = aes.float64()
+    scalar_out1 = aes.exp(scalar_inp)
+    scalar_out2 = aes.log(scalar_inp)
+    scalar_composite = aes.Composite([scalar_inp], [scalar_out1, scalar_out2])
+
+    tensor_inp = at.dvector()
+    tensor_outs = Elemwise(scalar_composite)(tensor_inp)
+
+    out_fg = FunctionGraph([tensor_inp], tensor_outs)
+
+    print("")
+    compare_numba_and_py(out_fg, [np.r_[1.0, 2.0, 3.5]])
+
+
 @pytest.mark.parametrize(
     "v, new_order",
     [

From 8c1ff36e97352ef7c89c8a47437fe3254c2c1258 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 26 Oct 2022 10:04:12 +0200
Subject: [PATCH 2/2] Try to get rid of looping, but this leads to NameError:
 global name 'scalar_op_fn' is not defined

---
 aesara/link/numba/dispatch/elemwise.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/aesara/link/numba/dispatch/elemwise.py b/aesara/link/numba/dispatch/elemwise.py
index 3cf70d847d..98412a1f43 100644
--- a/aesara/link/numba/dispatch/elemwise.py
+++ b/aesara/link/numba/dispatch/elemwise.py
@@ -194,8 +194,7 @@ def create_guvectorize_func(
 
     gu_fn_src = f"""
 def {gu_fn_name}({', '.join(input_names)}, {', '.join(output_names)}):
-    for i in range({input_names[0]}.shape[0]):
-        {'[i], '.join(output_names)}[i] = scalar_op_fn({'[i], '.join(input_names)}[i])
+    {'[()], '.join(output_names)}[()] = scalar_op_fn({'[()], '.join(input_names)}[()])
 """
     print(gu_fn_src)
 