1818 List , Optional , Dict , Tuple , Set , Any , TYPE_CHECKING , Sequence , Iterator
1919)
2020from .type import Type , InvalidType
21- from .typing_support import typeof_pyval , get_constant_value , loose_type_of_pyval
2221from cuda .tile ._exception import (
23- TileTypeError ,
24- TileValueError ,
25- Loc , TileInternalError
22+ TileTypeError , Loc , TileInternalError , TileSyntaxError
2623)
2724from .._cext import TileContext
2825
@@ -308,28 +305,122 @@ def finalize_loopvar_type(self, body_var: Var):
class LoopInfo:
    # Loop bookkeeping record. In this file it is accepted by nested_block()
    # as `loop_info` and stored on Builder as `Builder.loop_info`.
    # NOTE(review): the per-field notes below are inferred from names and
    # types only -- confirm against the code that constructs LoopInfo.
    var_states: tuple[LoopVarState, ...]  # presumably one state per loop-carried variable
    is_for_loop: bool  # presumably distinguishes `for` loops from other loop kinds
    stored_names: tuple[str, ...]  # presumably the names assigned inside the loop body
    flatten: bool = False  # defaults to False; its effect is not visible in this excerpt
311310
312311
313- def get_innermost_loop () -> LoopInfo | None :
314- return Builder .get_current ().loop_info
@dataclass
class IfElseInfo:
    # If/else bookkeeping record. In this file it is accepted by
    # nested_block() as `if_else_info` and stored on Builder as
    # `Builder.if_else_info`.
    # NOTE(review): the per-field notes below are inferred from names and
    # types only -- confirm against the code that constructs IfElseInfo.
    result_phis: tuple[PhiState, ...]  # presumably one phi state per value produced by the if/else
    stored_names: tuple[str, ...]  # presumably the names assigned inside the branches
    flatten: bool = False  # defaults to False; its effect is not visible in this excerpt
    flattened_results: tuple[Var, ...] = ()  # empty by default; presumably filled only when `flatten` is set
    have_end_branch: bool = False  # defaults to False; presumably set once a branch terminator is emitted
315319
316320
@contextmanager
def nested_block(name: str, loc: Loc, params: Sequence[Var] = (),
                 loop_info: LoopInfo | None = None,
                 if_else_info: IfElseInfo | None = None):
    """Build a nested ``Block`` whose operations are collected by a child ``Builder``.

    The child builder reuses the current builder's IR context and scope.
    ``loop_info`` and ``if_else_info`` are inherited from the current builder
    unless a (truthy) replacement is passed in.  While the ``with`` body runs,
    the child builder is active and the local scope is inside
    ``enter_branch()``.

    Yields:
        The newly created block.  After the body finishes normally, the
        operations gathered by the child builder are appended to it.  If the
        body raises, the block is not extended.
    """
    parent = Builder.get_current()
    block = Block(parent.ir_ctx, params=params, name=name, loc=loc)

    # An explicitly supplied info object wins; otherwise inherit the parent's.
    effective_loop_info = loop_info if loop_info else parent.loop_info
    effective_if_else_info = if_else_info if if_else_info else parent.if_else_info
    scope = parent.scope

    child = Builder(parent.ir_ctx, loc, scope,
                    effective_loop_info, effective_if_else_info)
    with child as builder:
        with scope.local.enter_branch():
            yield block
    block.extend(builder.ops)
326334
327335
class LocalScope:
    """Tracks the current ``Var`` bound to each local name in a chain of scopes.

    Each scope holds the set of names that are local to it and a map from
    name to its current definition.  Lookups that miss here continue in
    ``parent``.
    """

    def __init__(self,
                 all_locals: Set[str],
                 ir_ctx: IRContext,
                 parent: Optional["LocalScope"] = None):
        self._all_locals = all_locals
        self._ir_ctx = ir_ctx
        self._map = {}
        self._parent = parent

    def is_local_name(self, name: str):
        """Return True if ``name`` is a local of this scope or of any ancestor."""
        scope = self
        while scope is not None and name not in scope._all_locals:
            scope = scope._parent
        # Falling off the end of the chain means no scope declared the name.
        return scope is not None

    def redefine(self, name: str, loc: Loc) -> Var:
        """Create a new ``Var`` for ``name`` and record it as the current definition."""
        fresh = self._ir_ctx.make_var(name, loc)
        self._map[name] = fresh
        return fresh

    def __getitem__(self, name: str):
        """Return the current definition of ``name``.

        Raises:
            TileSyntaxError: if no scope in the chain defines ``name``.
        """
        found = self._lookup(name)
        if found is not None:
            return found
        raise TileSyntaxError(f"Undefined variable {name} used")

    def get(self, name: str, loc: Loc):
        """Return the current definition of ``name``, or an undefined placeholder.

        When the name is not defined anywhere in the chain, a new variable is
        created with ``undefined=True``; it is not recorded in the scope.
        """
        found = self._lookup(name)
        if found is not None:
            return found
        return self._ir_ctx.make_var(name, loc, undefined=True)

    def _lookup(self, name: str) -> Optional[Var]:
        """Search this scope and then its ancestors; return None on a miss."""
        visited = set()
        scope = self
        while scope is not None:
            hit = scope._map.get(name)
            if hit is not None:
                return hit
            # Sanity check: a well-formed parent chain never revisits a scope.
            scope_id = id(scope)
            if scope_id in visited:
                raise RuntimeError("Cycle detected in Scope chain")
            visited.add(scope_id)
            scope = scope._parent
        return None

    @contextmanager
    def enter_branch(self):
        """Layer an overlay over the map for the duration of the ``with`` body.

        Definitions made inside the body land in the overlay; the previous
        map object is reinstated on exit, even if the body raises.
        """
        saved_map = self._map
        self._map = _OverlayDict(saved_map)
        try:
            yield
        finally:
            self._map = saved_map
395+
396+ class _OverlayDict :
397+ def __init__ (self , orig_dict : dict ):
398+ self ._orig = orig_dict
399+ self ._overlay = dict ()
400+
401+ def get (self , key ):
402+ value = self ._overlay .get (key )
403+ return self ._orig .get (key ) if value is None else value
404+
405+ def __setitem__ (self , key , value ):
406+ self ._overlay [key ] = value
407+
408+
@dataclass
class Scope:
    # Name-resolution state carried by a Builder (`Builder.scope`).
    # nested_block() passes the current builder's Scope unchanged to the
    # child builder, so nested blocks share one Scope object.
    local: LocalScope  # local-variable scope; nested_block() wraps its body in local.enter_branch()
    frozen_globals: Mapping[str, Any]  # name -> value mapping; presumably a read-only snapshot of globals -- confirm
413+
414+
328415class Builder :
329- def __init__ (self , ctx : IRContext , loc : Loc , loop_info : LoopInfo | None = None ):
416+ def __init__ (self , ctx : IRContext , loc : Loc , scope : Scope ,
417+ loop_info : LoopInfo | None = None ,
418+ if_else_info : IfElseInfo | None = None ):
330419 self .ir_ctx = ctx
420+ self .scope = scope
331421 self .is_terminated = False
332422 self .loop_info = loop_info
423+ self .if_else_info = if_else_info
333424 self ._loc = loc
334425 self ._ops = []
335426 self ._entered = False
@@ -363,6 +454,10 @@ def add_operation(self, op_class,
363454 def ops (self ) -> list [Operation ]:
364455 return self ._ops
365456
@property
def loc(self) -> Loc:
    """Return the source location currently associated with this builder."""
    return self._loc
460+
def append_verbatim(self, op: Operation):
    """Append ``op`` to this builder's operation list as-is."""
    self._ops.append(op)
368463
@@ -384,6 +479,24 @@ def change_loc(self, loc: Loc):
384479 finally :
385480 self ._loc = old_loc
386481
@contextmanager
def change_if_else_info(self, new_info: IfElseInfo):
    """Temporarily replace ``self.if_else_info`` with ``new_info``.

    The previous value is restored when the ``with`` body exits, whether it
    finishes normally or raises.
    """
    saved, self.if_else_info = self.if_else_info, new_info
    try:
        yield
    finally:
        self.if_else_info = saved
490+
@contextmanager
def change_loop_info(self, new_info: LoopInfo):
    """Temporarily replace ``self.loop_info`` with ``new_info``.

    The previous value is restored when the ``with`` body exits, whether it
    finishes normally or raises.
    """
    saved, self.loop_info = self.loop_info, new_info
    try:
        yield
    finally:
        self.loop_info = saved
499+
387500 def __enter__ (self ):
388501 assert not self ._entered
389502 self ._prev_builder = _current_builder .builder
@@ -685,69 +798,8 @@ def __str__(self) -> str:
685798 return self .to_string ()
686799
687800
688- def bind_kernel_arguments (params : Tuple [Var , ...],
689- args : Tuple [Any , ...],
690- constant_args : Set [str ]) -> Tuple [Argument , ...]:
691- # TODO: unify this logic with dispatcher from c extension
692- # Refactor "extract_cuda_args" to return type descriptor
693- # that can be wrapped as IR Type for type inference.
694- if len (args ) != len (params ):
695- msg = f"Expected { len (params )} arguments, got { len (args )} "
696- raise TileValueError (msg )
697-
698- ir_args = []
699- for param , arg_value in zip (params , args ):
700- const_val = None
701- is_const = param .name in constant_args
702- ty = typeof_pyval (arg_value , kernel_arg = not is_const )
703- loose_type = ty
704- if is_const :
705- try :
706- const_val = get_constant_value (arg_value )
707- except TileTypeError :
708- raise TileTypeError (
709- f"Argument { param .name } is a constexpr, "
710- f"but the value is not a supported constant." )
711- loose_type = loose_type_of_pyval (arg_value )
712- ir_args .append (Argument (type = ty ,
713- loose_type = loose_type ,
714- is_const = is_const ,
715- const_value = const_val ))
716- return tuple (ir_args )
717-
718-
@dataclass
class Argument:
    # Record describing one argument: its type plus optional constant info.
    # Equality is the dataclass-generated field-wise comparison of
    # (type, is_const, const_value).
    type: Type  # the argument's IR type
    is_const: bool = False  # whether the argument is a constant; defaults to False
    const_value: Any = None  # the constant's value; None by default (presumably only meaningful when is_const -- confirm)
0 commit comments