Skip to content

Commit b9d077f

Browse files
committed
Use TensorLikeTy in static_eval() etc.
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 6853570 commit b9d077f

5 files changed

Lines changed: 261 additions & 225 deletions

File tree

experimental/cuda-lang/src/cuda/lang/_ir/ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,6 @@
4040
)
4141
from cuda.tile._ir.ops import (
4242
tile_impl_registry,
43-
bind_method,
44-
build_tuple,
4543
Return,
4644
return_,
4745
Assign,
@@ -63,6 +61,8 @@
6361
astype,
6462
)
6563
from cuda.tile._ir.core_ops import (
64+
bind_method,
65+
build_tuple,
6666
loosely_typed_const,
6767
strictly_typed_const,
6868
)

src/cuda/tile/_ir/core_ops.py

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,34 @@
11
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# SPDX-License-Identifier: Apache-2.0
4-
4+
import dataclasses
55
from dataclasses import dataclass
6+
from types import MethodType, FunctionType, BuiltinFunctionType
67
from typing import Any, Optional
78

89
from typing_extensions import override
910

1011
import cuda.tile._bytecode as bc
11-
from cuda.tile._ir.ir import Operation, attribute, Var, Builder
12-
from cuda.tile._ir.type import Type, DTypeSpec, TensorLikeTy
12+
from cuda.tile import TileTypeError
13+
from cuda.tile._datatype import numeric_dtype_category
14+
from cuda.tile._ir import hir_stubs
15+
from cuda.tile._ir.ir import Operation, attribute, Var, Builder, make_aggregate
16+
from cuda.tile._ir.op_impl import ImplRegistry, require_dataclass_type
17+
from cuda.tile._ir.type import Type, DTypeSpec, TensorLikeTy, TupleTy, TupleValue, Symbol, \
18+
DataclassInfo, DataclassTy, DataclassValue, BoundMethodValue, BoundMethodTy
1319
from cuda.tile._ir.typing_support import type_of_constant_python_value, \
14-
loose_type_of_constant_python_value
20+
loose_type_of_constant_python_value, get_dataclass_info, as_third_party_dtype_spec
1521
from cuda.tile._ir2bytecode import BytecodeContext
1622

1723

24+
_registry = ImplRegistry()
25+
impl = _registry.impl
26+
27+
28+
def core_impl_registry() -> ImplRegistry:
29+
return _registry
30+
31+
1832
@dataclass(eq=False)
1933
class TypedConst(Operation, opcode="typed_const"):
2034
value: Any = attribute()
@@ -56,3 +70,83 @@ def _strictly_typed_const_inner(builder: Builder,
5670
# We currently don't have a way to represent an N-dimensional tile constant
5771
ret.set_constant(value)
5872
return ret
73+
74+
75+
@impl(hir_stubs.build_tuple)
76+
def build_tuple(items: tuple[Var, ...]) -> Var:
77+
ty = TupleTy(tuple(x.get_type() for x in items))
78+
loose_ty = TupleTy(tuple(x.get_loose_type() for x in items))
79+
res = make_aggregate(TupleValue(items), ty, loose_ty)
80+
if all(x.is_constant() for x in items):
81+
res.set_constant(tuple(x.get_constant() for x in items))
82+
return res
83+
84+
85+
def build_dataclass_instance(items: tuple[Var, ...], info: DataclassInfo) -> Var:
86+
cls = info.cls
87+
ty = DataclassTy(cls, tuple(x.get_type() for x in items))
88+
loose_ty = DataclassTy(cls, tuple(x.get_loose_type() for x in items))
89+
res = make_aggregate(DataclassValue(items, info), ty, loose_ty)
90+
if all(x.is_constant() for x in items):
91+
const_val = cls(**{name: x.get_constant()
92+
for name, x in zip(info.field_names, items, strict=True)})
93+
res.set_constant(const_val)
94+
return res
95+
96+
97+
@impl(dataclasses.replace)
98+
def dataclasses_replace_impl(obj: Var, changes: dict[str, Var]):
99+
dataclass_ty = require_dataclass_type(obj)
100+
dataclass_val = obj.get_aggregate()
101+
assert isinstance(dataclass_val, DataclassValue)
102+
name2idx = dataclass_val.info.field_name_to_idx
103+
new_items = list(dataclass_val.items)
104+
for name, val in changes.items():
105+
try:
106+
idx = name2idx[name]
107+
except KeyError:
108+
raise TileTypeError(f"Dataclass '{dataclass_ty.cls.__name__}'"
109+
f" has no such field '{name}'")
110+
new_items[idx] = val
111+
return build_dataclass_instance(tuple(new_items), dataclass_val.info)
112+
113+
114+
def bind_method(object: Var, func) -> Var:
115+
agg_value = BoundMethodValue(object)
116+
res_ty = BoundMethodTy(object.get_type(), func)
117+
return make_aggregate(agg_value, res_ty)
118+
119+
120+
def sym2var(x: Any, constant_only: bool = False) -> Var:
121+
# TODO: verify we don't have a stale closure
122+
123+
if isinstance(x, Symbol):
124+
if constant_only:
125+
raise TileTypeError("Cannot create a constant from a symbolic value")
126+
return x._var
127+
128+
if isinstance(x, tuple):
129+
return build_tuple(tuple(sym2var(item, constant_only=constant_only) for item in x))
130+
131+
cls = type(x)
132+
if dataclasses.is_dataclass(cls):
133+
info = get_dataclass_info(cls)
134+
field_vars = tuple(sym2var(getattr(x, f.name), constant_only=constant_only)
135+
for f in dataclasses.fields(cls))
136+
return build_dataclass_instance(field_vars, info)
137+
138+
if isinstance(x, MethodType):
139+
self_var = sym2var(x.__self__, constant_only=constant_only)
140+
if not isinstance(x.__func__, FunctionType | BuiltinFunctionType):
141+
raise TileTypeError(f"Object of type {type(x).__name__}"
142+
f" cannot be used as a function for binding a method")
143+
return bind_method(self_var, x.__func__)
144+
145+
# Transform a third party typed scalar (e.g., np.int16(5)) into a strictly typed constant
146+
dtype_spec = as_third_party_dtype_spec(type(x))
147+
if dtype_spec is not None:
148+
pyval = numeric_dtype_category(dtype_spec.dtype).pytype(x)
149+
ty = Builder.get_current().ir_ctx.typing_hooks.get_tensor_like_type(dtype_spec.dtype, ())
150+
return strictly_typed_const(pyval, ty)
151+
152+
return loosely_typed_const(x)

0 commit comments

Comments
 (0)