Skip to content

Commit 2117029

Browse files
committed
Add support for custom scan via ct.scan
Signed-off-by: Ziheng Deng <zihengd@nvidia.com>
1 parent 394d302 commit 2117029

11 files changed

Lines changed: 591 additions & 144 deletions

File tree

changelog.d/custom-scan.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Added support for custom scan via `ct.scan()`.

docs/source/operations.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ Scan
7878

7979
cumsum
8080
cumprod
81+
scan
8182

8283

8384
Matmul

src/cuda/tile/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
cosh,
9191
cumprod,
9292
cumsum,
93+
scan,
9394
equal,
9495
exp,
9596
exp2,
@@ -224,6 +225,7 @@
224225
"cosh",
225226
"cumprod",
226227
"cumsum",
228+
"scan",
227229
"equal",
228230
"exp",
229231
"exp2",

src/cuda/tile/_ir/ir.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,8 @@ def finalize_loopvar_type(self, body_var: Var):
414414

415415

416416
class Builder:
417-
def __init__(self, ctx: IRContext, loc: Loc, reduction_body: bool = False):
417+
def __init__(self, ctx: IRContext, loc: Loc, reduction_body: bool = False,
418+
scan_body: bool = False):
418419
self.ir_ctx = ctx
419420
self.is_terminated = False
420421
self._loc = loc
@@ -423,14 +424,18 @@ def __init__(self, ctx: IRContext, loc: Loc, reduction_body: bool = False):
423424
self._prev_builder = None
424425
self._var_map: Dict[str, Var] = dict()
425426
self.reduction_body = reduction_body
427+
self.scan_body = scan_body
426428

427429
def add_operation(self, op_class,
428430
result_ty: Type | None | Tuple[Type | None, ...],
429431
attrs_and_operands,
430432
result: Var | Sequence[Var] | None = None) -> Var | Tuple[Var, ...]:
431-
if self.reduction_body and op_class.memory_effect != MemoryEffect.NONE:
432-
raise TileSyntaxError("Operations with memory effects are not supported"
433-
" inside reduction body")
433+
if (self.reduction_body or self.scan_body) and op_class.memory_effect != MemoryEffect.NONE:
434+
if self.reduction_body:
435+
msg = "Operations with memory effects are not supported inside reduction body"
436+
else:
437+
msg = "Operations with memory effects are not supported inside scan body"
438+
raise TileSyntaxError(msg)
434439

435440
assert not self.is_terminated
436441
force_type = False
@@ -524,11 +529,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):
524529

525530

526531
@contextmanager
527-
def nested_block(loc: Loc, reduction_body: bool = False):
532+
def nested_block(loc: Loc, reduction_body: bool = False, scan_body: bool = False):
528533
prev_builder = Builder.get_current()
529534
block = Block(prev_builder.ir_ctx, loc=loc)
530535
with Builder(prev_builder.ir_ctx, loc,
531-
reduction_body=reduction_body or prev_builder.reduction_body) as builder:
536+
reduction_body=reduction_body or prev_builder.reduction_body,
537+
scan_body=scan_body or prev_builder.scan_body) as builder:
532538
yield block
533539
block.extend(builder.ops)
534540

0 commit comments

Comments
 (0)