|
1 | 1 | # SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | 2 | # |
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | | - |
| 4 | +import dataclasses |
5 | 5 | from dataclasses import dataclass |
| 6 | +from types import MethodType, FunctionType, BuiltinFunctionType |
6 | 7 | from typing import Any, Optional |
7 | 8 |
|
8 | 9 | from typing_extensions import override |
9 | 10 |
|
10 | 11 | import cuda.tile._bytecode as bc |
11 | | -from cuda.tile._ir.ir import Operation, attribute, Var, Builder |
12 | | -from cuda.tile._ir.type import Type, DTypeSpec, TensorLikeTy |
| 12 | +from cuda.tile import TileTypeError |
| 13 | +from cuda.tile._datatype import numeric_dtype_category |
| 14 | +from cuda.tile._ir import hir_stubs |
| 15 | +from cuda.tile._ir.ir import Operation, attribute, Var, Builder, make_aggregate |
| 16 | +from cuda.tile._ir.op_impl import ImplRegistry, require_dataclass_type |
| 17 | +from cuda.tile._ir.type import Type, DTypeSpec, TensorLikeTy, TupleTy, TupleValue, Symbol, \ |
| 18 | + DataclassInfo, DataclassTy, DataclassValue, BoundMethodValue, BoundMethodTy |
13 | 19 | from cuda.tile._ir.typing_support import type_of_constant_python_value, \ |
14 | | - loose_type_of_constant_python_value |
| 20 | + loose_type_of_constant_python_value, get_dataclass_info, as_third_party_dtype_spec |
15 | 21 | from cuda.tile._ir2bytecode import BytecodeContext |
16 | 22 |
|
17 | 23 |
|
| 24 | +_registry = ImplRegistry() |
| 25 | +impl = _registry.impl |
| 26 | + |
| 27 | + |
| 28 | +def core_impl_registry() -> ImplRegistry: |
| 29 | + return _registry |
| 30 | + |
| 31 | + |
18 | 32 | @dataclass(eq=False) |
19 | 33 | class TypedConst(Operation, opcode="typed_const"): |
20 | 34 | value: Any = attribute() |
@@ -56,3 +70,83 @@ def _strictly_typed_const_inner(builder: Builder, |
56 | 70 | # We currently don't have a way to represent an N-dimensional tile constant |
57 | 71 | ret.set_constant(value) |
58 | 72 | return ret |
| 73 | + |
| 74 | + |
| 75 | +@impl(hir_stubs.build_tuple) |
| 76 | +def build_tuple(items: tuple[Var, ...]) -> Var: |
| 77 | + ty = TupleTy(tuple(x.get_type() for x in items)) |
| 78 | + loose_ty = TupleTy(tuple(x.get_loose_type() for x in items)) |
| 79 | + res = make_aggregate(TupleValue(items), ty, loose_ty) |
| 80 | + if all(x.is_constant() for x in items): |
| 81 | + res.set_constant(tuple(x.get_constant() for x in items)) |
| 82 | + return res |
| 83 | + |
| 84 | + |
| 85 | +def build_dataclass_instance(items: tuple[Var, ...], info: DataclassInfo) -> Var: |
| 86 | + cls = info.cls |
| 87 | + ty = DataclassTy(cls, tuple(x.get_type() for x in items)) |
| 88 | + loose_ty = DataclassTy(cls, tuple(x.get_loose_type() for x in items)) |
| 89 | + res = make_aggregate(DataclassValue(items, info), ty, loose_ty) |
| 90 | + if all(x.is_constant() for x in items): |
| 91 | + const_val = cls(**{name: x.get_constant() |
| 92 | + for name, x in zip(info.field_names, items, strict=True)}) |
| 93 | + res.set_constant(const_val) |
| 94 | + return res |
| 95 | + |
| 96 | + |
| 97 | +@impl(dataclasses.replace) |
| 98 | +def dataclasses_replace_impl(obj: Var, changes: dict[str, Var]): |
| 99 | + dataclass_ty = require_dataclass_type(obj) |
| 100 | + dataclass_val = obj.get_aggregate() |
| 101 | + assert isinstance(dataclass_val, DataclassValue) |
| 102 | + name2idx = dataclass_val.info.field_name_to_idx |
| 103 | + new_items = list(dataclass_val.items) |
| 104 | + for name, val in changes.items(): |
| 105 | + try: |
| 106 | + idx = name2idx[name] |
| 107 | + except KeyError: |
| 108 | + raise TileTypeError(f"Dataclass '{dataclass_ty.cls.__name__}'" |
| 109 | + f" has no such field '{name}'") |
| 110 | + new_items[idx] = val |
| 111 | + return build_dataclass_instance(tuple(new_items), dataclass_val.info) |
| 112 | + |
| 113 | + |
| 114 | +def bind_method(object: Var, func) -> Var: |
| 115 | + agg_value = BoundMethodValue(object) |
| 116 | + res_ty = BoundMethodTy(object.get_type(), func) |
| 117 | + return make_aggregate(agg_value, res_ty) |
| 118 | + |
| 119 | + |
| 120 | +def sym2var(x: Any, constant_only: bool = False) -> Var: |
| 121 | + # TODO: verify we don't have a stale closure |
| 122 | + |
| 123 | + if isinstance(x, Symbol): |
| 124 | + if constant_only: |
| 125 | + raise TileTypeError("Cannot create a constant from a symbolic value") |
| 126 | + return x._var |
| 127 | + |
| 128 | + if isinstance(x, tuple): |
| 129 | + return build_tuple(tuple(sym2var(item, constant_only=constant_only) for item in x)) |
| 130 | + |
| 131 | + cls = type(x) |
| 132 | + if dataclasses.is_dataclass(cls): |
| 133 | + info = get_dataclass_info(cls) |
| 134 | + field_vars = tuple(sym2var(getattr(x, f.name), constant_only=constant_only) |
| 135 | + for f in dataclasses.fields(cls)) |
| 136 | + return build_dataclass_instance(field_vars, info) |
| 137 | + |
| 138 | + if isinstance(x, MethodType): |
| 139 | + self_var = sym2var(x.__self__, constant_only=constant_only) |
| 140 | + if not isinstance(x.__func__, FunctionType | BuiltinFunctionType): |
| 141 | + raise TileTypeError(f"Object of type {type(x).__name__}" |
| 142 | + f" cannot be used as a function for binding a method") |
| 143 | + return bind_method(self_var, x.__func__) |
| 144 | + |
| 145 | + # Transform a third party typed scalar (e.g., np.int16(5)) into a strictly typed constant |
| 146 | + dtype_spec = as_third_party_dtype_spec(type(x)) |
| 147 | + if dtype_spec is not None: |
| 148 | + pyval = numeric_dtype_category(dtype_spec.dtype).pytype(x) |
| 149 | + ty = Builder.get_current().ir_ctx.typing_hooks.get_tensor_like_type(dtype_spec.dtype, ()) |
| 150 | + return strictly_typed_const(pyval, ty) |
| 151 | + |
| 152 | + return loosely_typed_const(x) |
0 commit comments