Skip to content

Commit 61d6060

Browse files
committed
Introduce BlockRestriction interface for restricting which operations are allowed inside a block
Signed-off-by: Ziheng Deng <zihengd@nvidia.com>
1 parent 6653a56 commit 61d6060

3 files changed

Lines changed: 44 additions & 29 deletions

File tree

src/cuda/tile/_ir/ir.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from cuda.tile._exception import (
2121
TileTypeError, Loc, TileInternalError
2222
)
23-
from .. import TileSyntaxError
2423
from .._context import TileContextConfig
2524
from cuda.tile._bytecode.version import BytecodeVersion
2625

@@ -324,6 +323,14 @@ class MemoryEffect(enum.IntEnum):
324323
STORE = 2
325324

326325

326+
class BlockRestriction:
327+
"""Interface for restricting which operations are allowed inside a block."""
328+
329+
def validate_operation(self, op_class: type) -> None:
330+
"""Raise if the given operation class is not allowed. No restriction by default."""
331+
return
332+
333+
327334
class Mapper:
328335
def __init__(self, ctx: IRContext, preserve_vars: bool = False):
329336
self._ctx = ctx
@@ -396,28 +403,23 @@ def finalize_loopvar_type(self, body_var: Var):
396403

397404

398405
class Builder:
399-
def __init__(self, ctx: IRContext, loc: Loc, reduction_body: bool = False,
400-
scan_body: bool = False):
406+
def __init__(self, ctx: IRContext, loc: Loc,
407+
block_restriction: Optional[BlockRestriction] = None):
401408
self.ir_ctx = ctx
402409
self.is_terminated = False
403410
self._loc = loc
404411
self._ops = []
405412
self._entered = False
406413
self._prev_builder = None
407414
self._var_map: Dict[str, Var] = dict()
408-
self.reduction_body = reduction_body
409-
self.scan_body = scan_body
415+
self.block_restriction = block_restriction
410416

411417
def add_operation(self, op_class,
412418
result_ty: Type | None | Tuple[Type | None, ...],
413419
attrs_and_operands,
414420
result: Var | Sequence[Var] | None = None) -> Var | Tuple[Var, ...]:
415-
if (self.reduction_body or self.scan_body) and op_class.memory_effect != MemoryEffect.NONE:
416-
if self.reduction_body:
417-
msg = "Operations with memory effects are not supported inside reduction body"
418-
else:
419-
msg = "Operations with memory effects are not supported inside scan body"
420-
raise TileSyntaxError(msg)
421+
if self.block_restriction is not None:
422+
self.block_restriction.validate_operation(op_class)
421423

422424
assert not self.is_terminated
423425
force_type = False
@@ -512,12 +514,11 @@ def __exit__(self, exc_type, exc_val, exc_tb):
512514

513515

514516
@contextmanager
515-
def enter_nested_block(loc: Loc, reduction_body: bool = False, scan_body: bool = False):
517+
def enter_nested_block(loc: Loc, block_restriction: Optional[BlockRestriction] = None):
516518
prev_builder = Builder.get_current()
517519
block = Block(prev_builder.ir_ctx, loc=loc)
518520
with Builder(prev_builder.ir_ctx, loc,
519-
reduction_body=reduction_body or prev_builder.reduction_body,
520-
scan_body=scan_body or prev_builder.scan_body) as builder:
521+
block_restriction=block_restriction or prev_builder.block_restriction) as builder:
521522
yield block
522523
block.extend(builder.ops)
523524

src/cuda/tile/_ir/ops.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
add_operation, Builder,
2424
enter_nested_block, nested_block, PhiState, LoopVarState,
2525
TupleValue, make_aggregate, RangeValue, BoundMethodValue, ArrayValue, ConstantState,
26-
ListValue, ClosureValue, MemoryEffect, attribute, operand,
26+
ListValue, ClosureValue, MemoryEffect, attribute, operand, BlockRestriction
2727
)
2828
from cuda.tile._ir.load_store_impl import (
2929
check_load_store_hints,
@@ -306,10 +306,8 @@ async def loop_impl(body: hir.Block, iterable: Var):
306306

307307
# Do this check at the end because this may be an automatically inserted loop
308308
# around the helper function's body.
309-
if builder.reduction_body:
310-
raise TileSyntaxError("Loops inside reduction body are not supported")
311-
if builder.scan_body:
312-
raise TileSyntaxError("Loops inside scan body are not supported")
309+
if builder.block_restriction is not None:
310+
builder.block_restriction.validate_operation(Loop)
313311

314312

315313
def _have_nested_jump(calls: Sequence[hir.Call]) -> bool:
@@ -349,6 +347,26 @@ def _to_string_rhs(self) -> str:
349347
return f"if(cond={self.cond})"
350348

351349

350+
@dataclass
351+
class ReduceScanRestriction(BlockRestriction):
352+
"""Restriction for reduction/scan body blocks: no memory effects, loops, or branching."""
353+
354+
kind: Literal["reduction", "scan"]
355+
356+
def validate_operation(self, op_class: type) -> None:
357+
if getattr(op_class, "memory_effect", MemoryEffect.NONE) != MemoryEffect.NONE:
358+
raise TileSyntaxError(
359+
f"Operations with memory effects are not supported inside {self.kind} body"
360+
)
361+
if op_class is Loop:
362+
raise TileSyntaxError(f"Loops inside {self.kind} body are not supported")
363+
if op_class is IfElse:
364+
raise TileSyntaxError(
365+
f"Branching inside {self.kind} body is not supported. "
366+
f"Consider ct.where() as a workaround."
367+
)
368+
369+
352370
async def _flatten_branch(branch: hir.Block) -> Var | None:
353371
from .._passes.hir2ir import dispatch_hir_block
354372
info = ControlFlowInfo((), flatten=True)
@@ -373,12 +391,8 @@ async def if_else_impl(cond: Var, then_block: hir.Block, else_block: hir.Block)
373391
return await _flatten_branch(branch_taken)
374392

375393
builder = Builder.get_current()
376-
if builder.reduction_body:
377-
raise TileSyntaxError("Branching inside reduction body is not supported."
378-
" Consider ct.where() as a workaround.")
379-
if builder.scan_body:
380-
raise TileSyntaxError("Branching inside scan body is not supported."
381-
" Consider ct.where() as a workaround.")
394+
if builder.block_restriction is not None:
395+
builder.block_restriction.validate_operation(IfElse)
382396

383397
# Get the total number of results by adding the number of stored variables.
384398
# Note: we sort the stored variable indices to make the order deterministic.
@@ -2889,7 +2903,7 @@ async def _get_reduce_scan_body_block(
28892903
"""Build body block for reduce/scan. Caller passes result_shape; returns
28902904
(body_block, result_types)."""
28912905
builder = Builder.get_current()
2892-
if builder.reduction_body or builder.scan_body:
2906+
if isinstance(builder.block_restriction, ReduceScanRestriction):
28932907
raise TileSyntaxError("Nested scan/reduction is not supported")
28942908

28952909
block_params = []
@@ -2913,8 +2927,7 @@ async def _get_reduce_scan_body_block(
29132927

29142928
with enter_nested_block(
29152929
builder.loc,
2916-
reduction_body=(op_name == "reduction"),
2917-
scan_body=(op_name == "scan")) as body_block:
2930+
block_restriction=ReduceScanRestriction(op_name)) as body_block:
29182931
body_block.params = tuple(block_params)
29192932
body_results = await body(tuple(lhs_vars), tuple(rhs_vars))
29202933
for body_res, x in zip(body_results, xs, strict=True):

test/test_scan.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,8 @@ def test_custom_scan_ifelse_not_supported():
297297
def kernel(x, y):
298298
def f(a, b):
299299
if ct.bid(0) == 0:
300-
return a + b
300+
# In case of type mismatch, compiler will complain about nested branching
301+
return a + b.type()
301302
else:
302303
return (a + b) % 5
303304

0 commit comments

Comments
 (0)