Skip to content

Commit cbda484

Browse files
committed
floordiv to support float operands
Signed-off-by: Ziheng Deng <zihengd@nvidia.com>
1 parent 8959f49 commit cbda484

File tree

4 files changed

+57
-5
lines changed

4 files changed

+57
-5
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- `ct.floordiv()` and the `//` operator now support floating-point operands.

src/cuda/tile/_ir/ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,11 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
944944
return bc.encode_DivIOp(ctx.builder, res_typeid, lhs, rhs,
945945
signedness=datatype.get_signedness(dtype),
946946
rounding=bc.RoundingMode.NEGATIVE_INF)
947+
case "floordiv", "float":
948+
quotient = bc.encode_DivFOp(ctx.builder, res_typeid, lhs, rhs,
949+
rounding_mode=rounding_mode,
950+
flush_to_zero=self.flush_to_zero)
951+
return bc.encode_FloorOp(ctx.builder, res_typeid, quotient)
947952
case "cdiv", "int":
948953
return bc.encode_DivIOp(ctx.builder, res_typeid, lhs, rhs,
949954
signedness=datatype.get_signedness(dtype),

src/cuda/tile/_stub.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1616,9 +1616,48 @@ def truediv(x, y, /, *, rounding_mode: Optional[RoundingMode] = None,
16161616
pass
16171617

16181618

1619-
@_doc_binary_op('//')
16201619
@function
16211620
def floordiv(x, y, /) -> TileOrScalar:
1621+
"""Elementwise floordiv on two tiles.
1622+
1623+
Can also use builtin operation ``x // y``.
1624+
1625+
Supports both integer and floating-point operands. For float inputs,
1626+
the result is ``floor(x / y)`` as a float (e.g. ``5.5 // 2.2 == 2.0``).
1627+
1628+
Args:
1629+
x (Tile): LHS tile.
1630+
y (Tile): RHS tile.
1631+
1632+
The ``shape`` of ``x`` and ``y`` will be broadcast and
1633+
``dtype`` promoted to a common dtype.
1634+
1635+
Returns:
1636+
Tile:
1637+
1638+
Examples:
1639+
1640+
>>> # integer tile and tile
1641+
>>> tx = ct.full((2, 4), 7, dtype=ct.int32)
1642+
>>> ty = ct.full((2, 4), 3, dtype=ct.int32)
1643+
>>> tz = ct.floordiv(tx, ty)
1644+
1645+
>>> # Can also use the builtin op
1646+
>>> tz = tx // ty
1647+
1648+
>>> # float tile and tile
1649+
>>> tx = ct.full((2, 4), 5.5, dtype=ct.float32)
1650+
>>> ty = ct.full((2, 4), 2.2, dtype=ct.float32)
1651+
>>> tz = tx // ty # result is ct.float32 with value 2.0
1652+
1653+
>>> # tile and scalar
1654+
>>> tx = ct.full((2, 4), 7, dtype=ct.int32)
1655+
>>> y = 2
1656+
>>> tz = tx // y
1657+
1658+
>>> # scalar and scalar
1659+
>>> z = 7 // 2
1660+
"""
16221661
pass
16231662

16241663

test/test_binary_elementwise.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -498,14 +498,16 @@ def test_array_scalar_div(shape, tile, int_dtype, tmp_path, op_symbol, ref_impl,
498498
@pytest.mark.parametrize("op_symbol, ref_impl", [
499499
("/", lambda x, y: x / y),
500500
("ct.truediv", lambda x, y: x / y),
501+
("//", lambda x, y: x // y),
502+
("ct.floordiv", lambda x, y: x // y),
501503
])
502-
def test_array_scalar_truediv_float(shape, tile, float_dtype, tmp_path, op_symbol, ref_impl):
504+
def test_array_scalar_div_float(shape, tile, float_dtype, tmp_path, op_symbol, ref_impl):
503505
x = make_tensor(shape, dtype=float_dtype, device='cuda')
504506
y = 23.0
505507
res_dtype = torch.promote_types(x.dtype, torch.float32)
506508
ref = ref_impl(x.to(res_dtype), y)
507509
z = torch.zeros_like(ref)
508-
kernel = array_scalar_kernel('truediv',
510+
kernel = array_scalar_kernel('div_float',
509511
f'tz = {op_symbol}(tx, y)' if op_symbol.startswith("ct.") else
510512
f'tz = tx {op_symbol} y',
511513
tmp_path)
@@ -540,17 +542,19 @@ def test_array_div(shape, tile, x_dtype, y_dtype, tmp_path, op_symbol, ref_impl,
540542
@pytest.mark.parametrize("op_symbol, ref_impl", [
541543
("/", lambda x, y: x / y),
542544
("ct.truediv", lambda x, y: x / y),
545+
("//", lambda x, y: torch.floor(x / y)),
546+
("ct.floordiv", lambda x, y: torch.floor(x / y)),
543547
])
544548
@pytest.mark.parametrize("x_dtype", float_dtypes, ids=dtype_id)
545549
@pytest.mark.parametrize("y_dtype", float_dtypes, ids=dtype_id)
546-
def test_array_truediv_float(shape, tile, x_dtype, y_dtype, tmp_path, op_symbol, ref_impl):
550+
def test_array_div_float(shape, tile, x_dtype, y_dtype, tmp_path, op_symbol, ref_impl):
547551
should_raise = {x_dtype, y_dtype} == {torch.float16, torch.bfloat16}
548552
x = (torch.rand(*shape, device="cuda") * 100).to(dtype=x_dtype)
549553
y = (torch.rand(*shape, device="cuda") * 100 + 1).to(dtype=y_dtype)
550554
result_type = torch.promote_types(x.dtype, y.dtype)
551555
z = torch.zeros_like(x).to(result_type)
552556
ref = ref_impl(x, y).to(result_type)
553-
kernel = array_kernel('truediv',
557+
kernel = array_kernel('div_float',
554558
f"tz = {op_symbol}(tx, ty)" if op_symbol.startswith("ct.") else
555559
f"tz = tx {op_symbol} ty",
556560
tmp_path)

0 commit comments

Comments
 (0)