-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy pathtest_dpex_target_overload_isolation.py
More file actions
81 lines (53 loc) · 1.67 KB
/
test_dpex_target_overload_isolation.py
File metadata and controls
81 lines (53 loc) · 1.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# SPDX-FileCopyrightText: 2024 Intel Corporation
#
# SPDX-License-Identifier: Apache-2.0
"""
Tests that dpex-target overloads and intrinsics are not available from
numba.njit and are available only from numba_dpex.dpjit.
"""
import pytest
from numba import njit, types
from numba.core import errors
from numba.extending import intrinsic, overload
from numba_dpex import dpjit
from numba_dpex.core.targets.dpjit_target import DPEX_TARGET_NAME
def foo():
    """Plain-Python stub that returns 1 when called directly.

    In compiled code, calls to ``foo`` are meant to resolve only through
    the dpex-target overload ``ol_foo``.
    """
    result = 1
    return result
@overload(foo, target=DPEX_TARGET_NAME)
def ol_foo():
    """Overload of ``foo`` registered only for the dpex target.

    Returns:
        function: An implementation that takes no arguments and returns
        the constant 1.
    """

    def impl():
        return 1

    return impl
@intrinsic(target=DPEX_TARGET_NAME)
def intrinsic_foo(
    ty_context,
):
    """A numba "intrinsic" that returns the int32 constant 1.

    It is registered only for the dpex target, so it should be callable
    from ``dpjit``-compiled code but not from ``numba.njit``.

    Args:
        ty_context (numba.core.typing.context.Context): The typing context
            for the codegen.

    Returns:
        tuple(numba.core.typing.templates.Signature, function): A tuple of
        numba function signature type and a function object.
    """
    # NOTE(review): the signature declares a single ``types.void`` argument,
    # although ``intrinsic_bar`` calls this intrinsic with no arguments.
    # Confirm whether ``types.int32()`` (zero arguments) was intended.
    sig = types.int32(types.void)

    def codegen(context, builder, sig, args: list):
        # No IR is built from the arguments; the result is just an int32
        # constant 1.
        return context.get_constant(types.int32, 1)

    return sig, codegen
def bar():
    """Call ``foo`` and return its result.

    Compiling this requires an implementation of ``foo`` visible to the
    chosen JIT target.
    """
    value = foo()
    return value
def intrinsic_bar():
    """Call the dpex-only intrinsic ``intrinsic_foo`` and return its result."""
    return intrinsic_foo()
def test_dpex_overload_from_njit():
    """A dpex-target overload must not be resolvable from ``numba.njit``."""
    expected_errors = (errors.TypingError, errors.UnsupportedError)
    compiled = njit(bar)
    with pytest.raises(expected_errors):
        compiled()
def test_dpex_overload_from_dpjit():
    """A dpex-target overload must resolve and run from ``dpjit``."""
    bar_dpjit = dpjit(bar)
    # ol_foo's implementation returns 1. Checking the value ensures the
    # compiled call ran the overload, not merely that compilation
    # succeeded.
    assert bar_dpjit() == 1
def test_dpex_intrinsic_from_njit():
    """A dpex-target intrinsic must not be resolvable from ``numba.njit``."""
    expected_errors = (errors.TypingError, errors.UnsupportedError)
    compiled = njit(intrinsic_bar)
    with pytest.raises(expected_errors):
        compiled()
def test_dpex_intrinsic_from_dpjit():
    """A dpex-target intrinsic must resolve and run from ``dpjit``."""
    bar_dpjit = dpjit(intrinsic_bar)
    # intrinsic_foo's codegen emits the int32 constant 1. Asserting the
    # value verifies the generated code, not just successful compilation.
    assert bar_dpjit() == 1