55import functools
66import operator
77from dataclasses import dataclass
8- from types import MethodType , FunctionType , BuiltinFunctionType
8+ from types import MethodType , FunctionType , BuiltinFunctionType , MappingProxyType
99from typing import Any , Optional
1010
1111from typing_extensions import override
2727 DataclassInfo , DataclassTy , DataclassValue , BoundMethodValue , BoundMethodTy , InvalidType , \
2828 ContextManagerTy , ContextManagerLifecycle , LiveCapturedScope , ClosureTy , ClosureValue , \
2929 RangeIterType , RangeValue , TypeTy , ModuleTy , NONE , SliceType , StringTy , FormattedStringTy , \
30- StringFormat , FormattedStringValue , FormattedPiece
30+ StringFormat , FormattedStringValue , FormattedPiece , DictTy , DictValue
3131from cuda .tile ._ir .typing_support import type_of_constant_python_value , \
3232 loose_type_of_constant_python_value , get_dataclass_info , as_third_party_dtype_spec
3333from cuda .tile ._ir2bytecode import BytecodeContext
@@ -65,6 +65,7 @@ def core_impl_registry() -> ImplRegistry:
6565@overload_dispatcher (operator .lshift , fixed_args = ["<<" ])
6666@overload_dispatcher (operator .rshift , fixed_args = [">>" ])
6767@overload_dispatcher (operator .matmul , fixed_args = ["@" ])
68+ @overload_dispatcher (hir_stubs .is_contained_in , fixed_args = ["'in'" ])
6869@overload_dispatcher (min , fixed_args = ["min" ])
6970@overload_dispatcher (max , fixed_args = ["max" ])
7071def binop_overload_dispatcher (name : str , x : Var , y : Var ):
@@ -76,6 +77,13 @@ def binop_overload_dispatcher(name: str, x: Var, y: Var):
7677 raise TileTypeError (f"Unsupported operand types for { name } : { x_ty } and { y_ty } " )
7778
7879
80+ @impl (hir_stubs .is_not_contained_in )
81+ async def is_not_contained_in_impl (x : Var , y : Var ):
82+ from .._passes .hir2ir import call_function
83+ contained = await call_function (hir_stubs .is_contained_in , x , y )
84+ return await call_function (operator .not_ , contained )
85+
86+
7987def comparison_operator_impl (registry : ImplRegistry , lhs_ty : type [Type ], rhs_ty : type [Type ]):
8088 def decorate (func ):
8189 for name in ("eq" , "ne" , "lt" , "le" , "gt" , "ge" ):
@@ -346,6 +354,71 @@ def len_tuple_impl(x: Var[TupleTy]) -> Var:
346354 return loosely_typed_const (len (x .get_type ()))
347355
348356
357+ # ===========================================================================================
358+ # Dictionary
359+ # ===========================================================================================
360+
361+ def build_dict (keys : tuple [str , ...], values : tuple [Var , ...]) -> Var :
362+ keys = tuple (keys )
363+ values = tuple (values )
364+ assert len (keys ) == len (values )
365+
366+ ty = DictTy (keys , tuple (x .get_type () for x in values ))
367+ loose_ty = DictTy (keys , tuple (x .get_loose_type () for x in values ))
368+ res = make_aggregate (DictValue (values ), ty , loose_ty )
369+ if all (x .is_constant () for x in values ):
370+ items = [(k , v .get_constant ()) for k , v in zip (keys , values , strict = True )]
371+ res .set_constant (MappingProxyType (dict (items )))
372+ return res
373+
374+
375+ def _find_dict_key_index (key : Var , dict_ty : DictTy ) -> int | None :
376+ key_ty = key .get_type ()
377+ if not isinstance (key_ty , StringTy ):
378+ # Python would happily report that the key is not found when a "wrong" key type is passed,
379+ # but we can add a stronger check here.
380+ raise TileTypeError (f"Dictionary keys must be strings, not { key_ty } " )
381+
382+ return dict_ty .keys .index (key_ty .value ) if key_ty .value in dict_ty .keys else None
383+
384+
385+ @impl (hir_stubs .is_contained_in , overload = (WILDCARD , DictTy ))
386+ async def is_contained_in_dict_impl (x : Var , y : Var [DictTy ]):
387+ return loosely_typed_const (_find_dict_key_index (x , y .get_type ()) is not None )
388+
389+
390+ @impl (getattr , overload = (DictTy , "get" ))
391+ def getattr_dict_method (object : Var , name : Var ):
392+ name = require_constant_str (name )
393+ unbound_func = getattr (dict , name )
394+ return bind_method (object , unbound_func )
395+
396+
397+ @impl (operator .getitem , overload = (DictTy , WILDCARD ))
398+ def getitem_dict_impl (object : Var [DictTy ], key : Var ):
399+ idx = _find_dict_key_index (key , object .get_type ())
400+ if idx is None :
401+ raise TileTypeError (f"Key '{ key .get_type ().value } ' not found in dictionary" )
402+ dict_value = object .get_aggregate ()
403+ assert isinstance (dict_value , DictValue )
404+ return dict_value .values [idx ]
405+
406+
407+ @impl (dict .get )
408+ def dict_get_impl (self : Var , key : Var , default : Var ):
409+ dict_ty = self .get_type ()
410+ if not isinstance (dict_ty , DictTy ):
411+ raise TileTypeError (f"dict.get() expects a dictionary as `self`, got { dict_ty } " )
412+
413+ idx = _find_dict_key_index (key , dict_ty )
414+ if idx is None :
415+ return default
416+
417+ dict_value = self .get_aggregate ()
418+ assert isinstance (dict_value , DictValue )
419+ return dict_value .values [idx ]
420+
421+
349422# ===========================================================================================
350423# Dataclass
351424# ===========================================================================================
0 commit comments