Skip to content

Commit bba2494

Browse files
committed
Limit the use of typeof_pyval()
This should be a non-functional change. Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent cba914c commit bba2494

3 files changed

Lines changed: 29 additions & 39 deletions

File tree

src/cuda/tile/_ir/ops.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -617,18 +617,6 @@ def return_(value: Var | None):
617617
add_operation_variadic(Return, ())
618618

619619

620-
def _check_value_numeric_type(value: Any, dtype: DType) -> None:
621-
value_type = typeof_pyval(value)
622-
if datatype.is_arithmetic(value_type):
623-
if not datatype.is_arithmetic(dtype):
624-
raise TileTypeError(f"Expect \"value\" to be a non-numeric dtype {dtype}, "
625-
f"got numeric dtype {value_type}")
626-
# TODO: Both are numeric types, check the data range after ir dtype supports it.
627-
else:
628-
if value_type != dtype:
629-
raise TileTypeError(f"Expect \"value\" to be a {dtype}, got {value_type}")
630-
631-
632620
@dataclass(eq=False)
633621
class TypedConst(Operation, opcode="typed_const"):
634622
value: Any = attribute()

src/cuda/tile/_ir/ops_utils.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from .ir import Operation, Builder
2020
from .type import TileTy, LooselyTypedScalar
21-
from .typing_support import typeof_pyval
21+
from .typing_support import dtype_of_constant_scalar
2222
from .._datatype import DType, _DTypePromotionImpl, NumericDTypeCategory, PointerInfo
2323

2424

@@ -206,7 +206,7 @@ def memory_order_has_release(memory_order: MemoryOrder):
206206

207207
def get_dtype(ty: TileTy | LooselyTypedScalar) -> datatype.DType:
208208
if isinstance(ty, LooselyTypedScalar):
209-
ty = typeof_pyval(ty.value)
209+
return dtype_of_constant_scalar(ty.value)
210210
assert isinstance(ty, TileTy)
211211
return ty.dtype
212212

@@ -241,9 +241,7 @@ class CompareOrdering(Enum):
241241
def _promote_dtype_and_loosely_typed_constant(dtype: DType,
242242
loose_const: Any,
243243
force_float: bool) -> DType:
244-
loose_ty = typeof_pyval(loose_const)
245-
assert isinstance(loose_ty, TileTy) and loose_ty.ndim == 0
246-
loose_dtype = loose_ty.dtype
244+
loose_dtype = dtype_of_constant_scalar(loose_const)
247245

248246
cat = datatype.numeric_dtype_category(dtype)
249247
if cat == NumericDTypeCategory.RestrictedFloat:
@@ -272,11 +270,9 @@ def promote_dtypes(t1: DType | LooselyTypedScalar,
272270
force_float: bool = False) -> DType:
273271
match t1, t2:
274272
case LooselyTypedScalar(val1), LooselyTypedScalar(val2):
275-
type1 = typeof_pyval(val1)
276-
assert isinstance(type1, TileTy)
277-
type2 = typeof_pyval(val2)
278-
assert isinstance(type2, TileTy)
279-
return _DTypePromotionImpl.promote_dtypes(type1.dtype, type2.dtype, force_float)
273+
dtype1 = dtype_of_constant_scalar(val1)
274+
dtype2 = dtype_of_constant_scalar(val2)
275+
return _DTypePromotionImpl.promote_dtypes(dtype1, dtype2, force_float)
280276
case LooselyTypedScalar(val), dtype:
281277
return _promote_dtype_and_loosely_typed_constant(dtype, val, force_float)
282278
case dtype, LooselyTypedScalar(val):
@@ -329,8 +325,8 @@ def check_implicit_cast(src_ty: TileTy | LooselyTypedScalar, target_dtype: DType
329325
raise TileValueError(f"cannot implicitly cast {src_ty.value}"
330326
f" to a non-numeric dtype {target_dtype}")
331327

332-
cocnrete_ty = typeof_pyval(src_ty.value)
333-
src_cat = datatype.numeric_dtype_category(cocnrete_ty.dtype)
328+
concrete_dtype = dtype_of_constant_scalar(src_ty.value)
329+
src_cat = datatype.numeric_dtype_category(concrete_dtype)
334330
dst_cat = datatype.numeric_dtype_category(target_dtype)
335331
if dst_cat == NumericDTypeCategory.Boolean:
336332
if src_cat not in (NumericDTypeCategory.Boolean, NumericDTypeCategory.Integral) \

src/cuda/tile/_ir/typing_support.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -139,28 +139,34 @@ def is_supported_builtin_func(x: Any) -> bool:
139139
return _safe_get(BUILTIN_FUNCS, x) is not None or getattr(x, '_cutile_is_builtin', False)
140140

141141

142-
def typeof_pyval(val) -> Type:
143-
if val is None:
144-
return NONE
145-
if (t := _safe_get(_dtype_registry, type(val))):
146-
return TileTy(t.dtype)
142+
def dtype_of_constant_scalar(val: bool | int | float) -> DType:
147143
if isinstance(val, bool):
148-
return TileTy(datatype.bool_)
149-
if isinstance(val, Enum):
150-
return EnumTy(type(val))
151-
if isinstance(val, int):
144+
return datatype.bool_
145+
elif isinstance(val, int):
152146
if -2**31 <= val < 2**31:
153-
dtype = datatype.int32
147+
return datatype.int32
154148
elif -2**63 <= val < 2**63:
155-
dtype = datatype.int64
149+
return datatype.int64
156150
elif 0 <= val < 2**64:
157-
dtype = datatype.uint64
151+
return datatype.uint64
158152
else:
159153
# FIXME: delay the error and allow arbitrary-precision intermediate constant values
160154
raise TileValueError(f"Constant {val} is out of range of any supported integer type")
161-
return TileTy(dtype)
162-
if isinstance(val, float):
163-
return TileTy(datatype.default_float_type)
155+
elif isinstance(val, float):
156+
return datatype.default_float_type
157+
else:
158+
raise TypeError(f'Python value {val} of type {type(val)} is not supported.')
159+
160+
161+
def typeof_pyval(val) -> Type:
162+
if val is None:
163+
return NONE
164+
if (t := _safe_get(_dtype_registry, type(val))):
165+
return TileTy(t.dtype)
166+
if isinstance(val, bool | int | float):
167+
return TileTy(dtype_of_constant_scalar(val))
168+
if isinstance(val, Enum):
169+
return EnumTy(type(val))
164170
if isinstance(val, str):
165171
return StringTy(val)
166172
if isinstance(val, tuple):

0 commit comments

Comments
 (0)