Skip to content

Commit 14e9353

Browse files
[lang] Support unary float math operations
This MR does not expose all the math operations from cuda.tile, just an initial subset of unary floating point math operations. They are also not imported by default; this is because I am not sure how we want to expose the math operations. Signed-off-by: Asher Mancinelli <amancinelli@nvidia.com>
1 parent d37b24d commit 14e9353

5 files changed

Lines changed: 252 additions & 4 deletions

File tree

experimental/cuda-lang/src/cuda/lang/_datatype.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
is_boolean,
3232
is_integral,
3333
is_signed,
34+
unsigned_integral_dtypes,
35+
signed_integral_dtypes,
3436
get_signedness,
3537
default_int_type,
3638
integer_dtype,
@@ -98,6 +100,8 @@ def to_torch_dtype(dtype: DType, /):
98100
"opaque_pointer_dtype",
99101
"get_signedness",
100102
"integer_dtype",
103+
"unsigned_integral_dtypes",
104+
"signed_integral_dtypes",
101105
"bool_",
102106
"uint8",
103107
"uint16",

experimental/cuda-lang/src/cuda/lang/_ir/ops.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,15 @@
8484
)
8585
from cuda.tile._exception import TileValueError
8686
import cuda.lang._mlir as mlir
87-
from .type_checking_helpers import require_array_indices, require_scalar_type, \
88-
require_pointer_type, require_signed_int_scalar_or_tuple, \
89-
require_clusterlaunchcontrol_token_type
87+
from .type_checking_helpers import (
88+
require_array_indices,
89+
require_scalar_or_vector_float_type,
90+
require_scalar_or_vector_type,
91+
require_scalar_type,
92+
require_pointer_type,
93+
require_signed_int_scalar_or_tuple,
94+
require_clusterlaunchcontrol_token_type,
95+
)
9096

9197
from .type import (
9298
LocalArrayContextManagerTy, ContextManagerState, TensorMapTy,
@@ -111,6 +117,7 @@
111117
format_var,
112118
LocalArrayContextManagerValue,
113119
)
120+
from .._stub import math as cl_math
114121
from .._stub.cluster_launch_control import clusterlaunchcontrol_try_cancel, \
115122
clusterlaunchcontrol_is_canceled, clusterlaunchcontrol_get_first_block_idx
116123
from .._stub.tensor_map import TensorMapSwizzle
@@ -1163,6 +1170,58 @@ class RawMLIROperation(Operation, opcode="mlir.operation",
11631170
mlir_attributes: tuple[tuple[str, mlir.Attribute], ...] = attribute(default=())
11641171

11651172

1173+
def _get_dtype(ty: ScalarTy | VectorTy):
1174+
match ty:
1175+
case ScalarTy() as st:
1176+
return st.dtype
1177+
case VectorTy() as vt:
1178+
return vt.element_dtype
1179+
case _:
1180+
assert False, "Match should have been exhaustive"
1181+
1182+
1183+
@impl(cl_math.ceil, fixed_args=["math.ceil"])
1184+
@impl(cl_math.sin, fixed_args=["math.sin"])
1185+
@impl(cl_math.cos, fixed_args=["math.cos"])
1186+
@impl(cl_math.tan, fixed_args=["math.tan"])
1187+
@impl(cl_math.sinh, fixed_args=["math.sinh"])
1188+
@impl(cl_math.cosh, fixed_args=["math.cosh"])
1189+
@impl(cl_math.tanh, fixed_args=["math.tanh"])
1190+
@impl(cl_math.sqrt, fixed_args=["math.sqrt"])
1191+
@impl(cl_math.floor, fixed_args=["math.floor"])
1192+
@impl(cl_math.log, fixed_args=["math.log"])
1193+
@impl(cl_math.log2, fixed_args=["math.log2"])
1194+
def math_float_unary_impl(op_name: str, x: Var):
1195+
x_ty = require_scalar_or_vector_float_type(x)
1196+
return add_operation(
1197+
RawMLIROperation,
1198+
x_ty,
1199+
op_name=op_name,
1200+
operands_=(x,),
1201+
)
1202+
1203+
1204+
@impl(cl_math.abs)
1205+
def abs_impl(x: Var) -> Var:
1206+
x_ty = require_scalar_or_vector_type(x)
1207+
x_dtype = _get_dtype(x_ty)
1208+
if datatype.is_float(x_dtype):
1209+
op_name = "math.absf"
1210+
elif datatype.is_integral(x_dtype):
1211+
# If it's unsigned, then the absolute value is the identity
1212+
if not datatype.is_signed(x_dtype):
1213+
return x
1214+
op_name = "math.absi"
1215+
else:
1216+
raise TileTypeError(f"abs() expects an arithmetic scalar, got {x_ty}")
1217+
return add_operation(
1218+
RawMLIROperation,
1219+
x_ty,
1220+
op_name=op_name,
1221+
operands_=(x,),
1222+
)
1223+
1224+
11661225
def _is_none(var: Var):
11671226
return var.is_constant() and var.get_constant() is None
11681227

experimental/cuda-lang/src/cuda/lang/_ir/type_checking_helpers.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from cuda.tile._ir.ops import implicit_cast
1010
from cuda.tile._ir.type import TupleTy, TupleValue
1111
from cuda.tile._datatype import is_integral, is_signed
12-
from cuda.lang._datatype import clusterlaunchcontrol_token
12+
from cuda.lang._datatype import clusterlaunchcontrol_token, is_float
1313

1414

1515
def require_array_indices(array: Var, indices: Var) -> tuple[Var, ...]:
@@ -85,6 +85,35 @@ def require_vector_type(var: Var, length: int | None = None) -> VectorTy:
8585
return ty
8686

8787

88+
def require_scalar_or_vector_float_type(var: Var) -> VectorTy | ScalarTy:
89+
ty = var.get_type()
90+
91+
def err():
92+
return make_type_checking_error(
93+
f"Expected a scalar or vector float type, but got {ty}", var
94+
)
95+
96+
match ty:
97+
case ScalarTy() as st:
98+
dtype = st.dtype
99+
case VectorTy() as vt:
100+
dtype = vt.element_dtype
101+
case _:
102+
raise err()
103+
104+
if not is_float(dtype):
105+
raise err()
106+
107+
return ty
108+
109+
110+
def require_scalar_or_vector_type(var: Var) -> VectorTy | ScalarTy:
111+
ty = var.get_type()
112+
if not isinstance(ty, ScalarTy | VectorTy):
113+
raise make_type_checking_error(f"Expected scalar or vector type but got {ty}", var)
114+
return ty
115+
116+
88117
def make_type_checking_error(message: str, culprit: Var | None = None):
89118
# TODO: recover the context similarly to _make_type_error in cutile
90119
raise TileTypeError(message)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
6+
from cuda.lang._execution import stub
7+
@stub
8+
def ceil(x, /): ...
9+
@stub
10+
def sin(x, /): ...
11+
@stub
12+
def cos(x, /): ...
13+
@stub
14+
def tan(x, /): ...
15+
@stub
16+
def sinh(x, /): ...
17+
@stub
18+
def cosh(x, /): ...
19+
@stub
20+
def tanh(x, /): ...
21+
@stub
22+
def sqrt(x, /): ...
23+
@stub
24+
def floor(x, /): ...
25+
@stub
26+
def log(x, /): ...
27+
@stub
28+
def log2(x, /): ...
29+
@stub
30+
def abs(x, /): ...
31+
32+
33+
__all__ = (
34+
"ceil",
35+
"sin",
36+
"cos",
37+
"tan",
38+
"sinh",
39+
"cosh",
40+
"tanh",
41+
"sqrt",
42+
"floor",
43+
"log",
44+
"log2",
45+
"abs",
46+
)
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import cuda.lang as cl
6+
import cuda.lang._datatype as datatype
7+
import builtins
8+
import math as host_math
9+
import torch
10+
import pytest
11+
from cuda.lang import compile_simt
12+
from cuda.lang._stub import math as device_math
13+
from cuda.lang.compilation import KernelSignature
14+
from cuda.lang._exception import TileTypeError
15+
16+
17+
FLOAT_TYPES = (
18+
cl.float16,
19+
cl.float32,
20+
cl.float64,
21+
)
22+
SIGNED_INT_TYPES = datatype.signed_integral_dtypes
23+
UNSIGNED_INT_TYPES = datatype.unsigned_integral_dtypes
24+
25+
UNARY_FLOAT_OPS = (
26+
(device_math.ceil, host_math.ceil),
27+
(device_math.sin, host_math.sin),
28+
(device_math.cos, host_math.cos),
29+
(device_math.tan, host_math.tan),
30+
(device_math.sinh, host_math.sinh),
31+
(device_math.cosh, host_math.cosh),
32+
(device_math.tanh, host_math.tanh),
33+
(device_math.sqrt, host_math.sqrt),
34+
(device_math.floor, host_math.floor),
35+
(device_math.log, host_math.log),
36+
(device_math.log2, host_math.log2),
37+
(device_math.abs, builtins.abs),
38+
)
39+
40+
41+
@pytest.mark.parametrize("dtype", FLOAT_TYPES)
42+
@pytest.mark.parametrize("device_op, host_op", UNARY_FLOAT_OPS)
43+
def test_math_unary_float(dtype, device_op, host_op):
44+
rng = torch.Generator().manual_seed(0)
45+
46+
@cl.kernel
47+
def kernel(inp, out):
48+
out[0] = device_op(inp[0])
49+
50+
torch_dt = datatype.to_torch_dtype(dtype)
51+
host_inp = torch.rand((), generator=rng).item() + 0.5
52+
expected = host_op(host_inp)
53+
inp = torch.tensor([host_inp], dtype=torch_dt, device="cuda")
54+
out = torch.tensor([0.0], dtype=torch_dt, device="cuda")
55+
cl.launch(torch.cuda.current_stream(), (1,), (1,), kernel, (inp, out))
56+
assert out[0].item() == pytest.approx(expected, rel=1e-3, abs=1e-3)
57+
58+
59+
@pytest.mark.parametrize("dtype", SIGNED_INT_TYPES)
60+
@pytest.mark.parametrize("host_inp", (-5, 0, 5))
61+
def test_math_abs_signed_int(dtype, host_inp):
62+
@cl.kernel
63+
def kernel(inp, out):
64+
out[0] = device_math.abs(dtype(inp[0]))
65+
66+
torch_dt = datatype.to_torch_dtype(dtype)
67+
expected = builtins.abs(host_inp)
68+
inp = torch.tensor([host_inp], dtype=torch_dt, device="cuda")
69+
out = torch.tensor([0], dtype=torch_dt, device="cuda")
70+
cl.launch(torch.cuda.current_stream(), (1,), (1,), kernel, (inp, out))
71+
assert out[0].item() == expected
72+
73+
74+
def test_math_abs_unsigned_int():
75+
# absolute value of unsigned number should be identity
76+
@cl.kernel
77+
def kernel():
78+
device_math.abs(cl.uint32(5.0))
79+
80+
result = compile_simt(kernel, [KernelSignature([])])
81+
assert "math.abs" not in result.mlir
82+
83+
84+
def test_vector():
85+
@cl.kernel
86+
def kernel(out):
87+
with cl.local_array(4, cl.float32) as arr:
88+
arr[0] = 0.5
89+
arr[1] = 1.5
90+
arr[2] = 2.5
91+
arr[3] = 3.5
92+
v = arr.get_base_pointer().load(count=4)
93+
v = device_math.floor(v)
94+
out.get_base_pointer().store(v)
95+
96+
out = torch.zeros(4, dtype=torch.float32).cuda()
97+
cl.launch(torch.cuda.current_stream(), (1,), (1,), kernel, (out,))
98+
print(out.cpu().tolist())
99+
torch.testing.assert_close(out.cpu().tolist(), [0.0, 1.0, 2.0, 3.0])
100+
101+
102+
def test_type_error():
103+
@cl.kernel
104+
def kernel():
105+
device_math.sin(cl.int32(5.0))
106+
107+
with pytest.raises(
108+
TileTypeError, match="Expected a scalar or vector float type, but got int32"
109+
):
110+
cl.launch(torch.cuda.current_stream(), (1,), (1,), kernel, ())

0 commit comments

Comments
 (0)