2323 add_operation , Builder ,
2424 enter_nested_block , nested_block , PhiState , LoopVarState ,
2525 TupleValue , make_aggregate , RangeValue , BoundMethodValue , ArrayValue , ConstantState ,
26- ListValue , ClosureValue , MemoryEffect , attribute , operand ,
26+ ListValue , ClosureValue , MemoryEffect , attribute , operand , BlockRestriction
2727)
2828from cuda .tile ._ir .load_store_impl import (
2929 check_load_store_hints ,
@@ -306,10 +306,8 @@ async def loop_impl(body: hir.Block, iterable: Var):
306306
307307 # Do this check at the end because this may be an automatically inserted loop
308308 # around the helper function's body.
309- if builder .reduction_body :
310- raise TileSyntaxError ("Loops inside reduction body are not supported" )
311- if builder .scan_body :
312- raise TileSyntaxError ("Loops inside scan body are not supported" )
309+ if builder .block_restriction is not None :
310+ builder .block_restriction .validate_operation (Loop )
313311
314312
315313def _have_nested_jump (calls : Sequence [hir .Call ]) -> bool :
@@ -349,6 +347,26 @@ def _to_string_rhs(self) -> str:
349347 return f"if(cond={ self .cond } )"
350348
351349
350+ @dataclass
351+ class ReduceScanRestriction (BlockRestriction ):
352+ """Restriction for reduction/scan body blocks: no memory effects, loops, or branching."""
353+
354+ kind : Literal ["reduction" , "scan" ]
355+
356+ def validate_operation (self , op_class : type ) -> None :
357+ if getattr (op_class , "memory_effect" , MemoryEffect .NONE ) != MemoryEffect .NONE :
358+ raise TileSyntaxError (
359+ f"Operations with memory effects are not supported inside { self .kind } body"
360+ )
361+ if op_class is Loop :
362+ raise TileSyntaxError (f"Loops inside { self .kind } body are not supported" )
363+ if op_class is IfElse :
364+ raise TileSyntaxError (
365+ f"Branching inside { self .kind } body is not supported. "
366+ f"Consider ct.where() as a workaround."
367+ )
368+
369+
352370async def _flatten_branch (branch : hir .Block ) -> Var | None :
353371 from .._passes .hir2ir import dispatch_hir_block
354372 info = ControlFlowInfo ((), flatten = True )
@@ -373,12 +391,8 @@ async def if_else_impl(cond: Var, then_block: hir.Block, else_block: hir.Block)
373391 return await _flatten_branch (branch_taken )
374392
375393 builder = Builder .get_current ()
376- if builder .reduction_body :
377- raise TileSyntaxError ("Branching inside reduction body is not supported."
378- " Consider ct.where() as a workaround." )
379- if builder .scan_body :
380- raise TileSyntaxError ("Branching inside scan body is not supported."
381- " Consider ct.where() as a workaround." )
394+ if builder .block_restriction is not None :
395+ builder .block_restriction .validate_operation (IfElse )
382396
383397 # Get the total number of results by adding the number of stored variables.
384398 # Note: we sort the stored variable indices to make the order deterministic.
@@ -2889,7 +2903,7 @@ async def _get_reduce_scan_body_block(
28892903 """Build body block for reduce/scan. Caller passes result_shape; returns
28902904 (body_block, result_types)."""
28912905 builder = Builder .get_current ()
2892- if builder .reduction_body or builder . scan_body :
2906+ if isinstance ( builder .block_restriction , ReduceScanRestriction ) :
28932907 raise TileSyntaxError ("Nested scan/reduction is not supported" )
28942908
28952909 block_params = []
@@ -2913,8 +2927,7 @@ async def _get_reduce_scan_body_block(
29132927
29142928 with enter_nested_block (
29152929 builder .loc ,
2916- reduction_body = (op_name == "reduction" ),
2917- scan_body = (op_name == "scan" )) as body_block :
2930+ block_restriction = ReduceScanRestriction (op_name )) as body_block :
29182931 body_block .params = tuple (block_params )
29192932 body_results = await body (tuple (lhs_vars ), tuple (rhs_vars ))
29202933 for body_res , x in zip (body_results , xs , strict = True ):
0 commit comments