Skip to content

Commit 24deef9

Browse files
committed
Add Int64 Array Annotation and Scalar Annotation
Signed-off-by: Qiqi Xiao <qiqix@nvidia.com>
1 parent 4daf534 commit 24deef9

15 files changed

Lines changed: 428 additions & 233 deletions

cext/tile_kernel.cpp

Lines changed: 147 additions & 88 deletions
Large diffs are not rendered by default.

changelog.d/int64-annotations.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- New `ct.IndexedWithInt64` annotation for array kernel parameters whose shape
5+
or stride values exceed the range of a 32-bit integer. Arrays without the
6+
annotation continue to use `int32` for shape and stride.
7+
- New `ct.ScalarInt64` annotation that forces a scalar integer kernel parameter
8+
to be inferred as `int64` instead of the default `int32`.

src/cuda/tile/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,13 @@
5555

5656
from cuda.tile._stub import (
5757
Array,
58+
ArrayAnnotation,
5859
Constant,
5960
ConstantAnnotation,
61+
IndexedWithInt64,
62+
ListAnnotation,
6063
Scalar,
64+
ScalarInt64,
6165
Tile,
6266
TiledView,
6367

@@ -204,9 +208,13 @@
204208
"TileValueError",
205209

206210
"Array",
211+
"ArrayAnnotation",
207212
"Constant",
208213
"ConstantAnnotation",
214+
"IndexedWithInt64",
215+
"ListAnnotation",
209216
"Scalar",
217+
"ScalarInt64",
210218
"Tile",
211219
"TiledView",
212220

src/cuda/tile/_annotated_function.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,58 @@
77
from types import FunctionType
88
from typing import (get_origin, get_args, Annotated, Any, Sequence)
99

10-
from cuda.tile._stub import ConstantAnnotation
10+
from cuda.tile._stub import ConstantAnnotation, ArrayAnnotation, ScalarAnnotation, ListAnnotation
11+
from cuda.tile._datatype import int64
1112

1213

1314
@dataclass
1415
class AnnotatedFunction:
1516
pyfunc: FunctionType
1617
pysig: inspect.Signature
1718
constant_parameter_mask: Sequence[bool]
19+
# array index dtype and scalar integer dtype can only be int64 or int32 now.
20+
int64_index_parameter_mask: Sequence[bool]
21+
int64_parameter_mask: Sequence[bool]
1822

1923

2024
def get_annotated_function(pyfunc: FunctionType) -> AnnotatedFunction:
2125
sig = inspect.signature(pyfunc)
2226
constant_parameter_mask = tuple(_has_constant_annotation(param.annotation)
2327
for param in sig.parameters.values())
28+
int64_index_parameter_mask = tuple(_has_int64_index_annotation(param.annotation)
29+
for param in sig.parameters.values())
30+
int64_parameter_mask = tuple(_has_int64_annotation(param.annotation)
31+
for param in sig.parameters.values())
2432
return AnnotatedFunction(pyfunc=pyfunc,
2533
pysig=sig,
26-
constant_parameter_mask=constant_parameter_mask)
34+
constant_parameter_mask=constant_parameter_mask,
35+
int64_index_parameter_mask=int64_index_parameter_mask,
36+
int64_parameter_mask=int64_parameter_mask)
2737

2838

2939
def _has_constant_annotation(annotation: Any) -> bool:
3040
if get_origin(annotation) is Annotated:
3141
_, *metadata = get_args(annotation)
3242
return any(isinstance(m, ConstantAnnotation) for m in metadata)
3343
return False
44+
45+
46+
def _has_int64_index_annotation(annotation: Any) -> bool:
47+
if get_origin(annotation) is Annotated:
48+
_, *metadata = get_args(annotation)
49+
for m in metadata:
50+
if isinstance(m, ArrayAnnotation) and m.index_dtype is int64:
51+
return True
52+
if (isinstance(m, ListAnnotation)
53+
and isinstance(m.element, ArrayAnnotation)
54+
and m.element.index_dtype is int64):
55+
return True
56+
return False
57+
return False
58+
59+
60+
def _has_int64_annotation(annotation: Any) -> bool:
61+
if get_origin(annotation) is Annotated:
62+
_, *metadata = get_args(annotation)
63+
return any(isinstance(m, ScalarAnnotation) and m.dtype is int64 for m in metadata)
64+
return False

src/cuda/tile/_compile.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from typing import Optional, Sequence
2424
import zipfile
2525

26-
from cuda.tile import int32
2726
from cuda.tile._annotated_function import AnnotatedFunction, get_annotated_function
2827
from cuda.tile._bytecode.version import BytecodeVersion
2928
from cuda.tile._cext import get_compute_capability, TileContext, default_tile_context
@@ -157,9 +156,6 @@ def _create_kernel_parameters(parameter_constraints: Sequence[ParameterConstrain
157156

158157

159158
def _get_array_ty(param: ArrayConstraint):
160-
if param.index_dtype != int32:
161-
raise NotImplementedError("Only int32 is currently supported as array index type")
162-
163159
for static_stride, bound in zip(param.stride_constant, param.stride_lower_bound_incl,
164160
strict=True):
165161
if static_stride is not None:
@@ -170,7 +166,8 @@ def _get_array_ty(param: ArrayConstraint):
170166

171167
return ArrayTy(make_tile_ty(param.dtype, ()),
172168
shape=(None,) * param.ndim,
173-
strides=param.stride_constant)
169+
strides=param.stride_constant,
170+
index_dtype=param.index_dtype)
174171

175172

176173
def _log_mlir(bytecode_buf):

src/cuda/tile/_execution.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ def __init__(self,
105105
occupancy=occupancy,
106106
opt_level=opt_level
107107
)
108-
super().__init__(ann_func.constant_parameter_mask)
108+
super().__init__(ann_func.constant_parameter_mask, ann_func.int64_index_parameter_mask,
109+
ann_func.int64_parameter_mask)
109110
self._annotated_function = ann_func
110111
self._compiler_options = compiler_options
111112

src/cuda/tile/_ir/load_store_impl.py

Lines changed: 0 additions & 126 deletions
This file was deleted.

src/cuda/tile/_ir/ops.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,12 +1326,20 @@ def generate_bytecode(self, ctx: BytecodeContext):
13261326
)
13271327

13281328
# Cast each of the i64 words to appropriate types
1329+
if list_ty.item_type.index_dtype.bitwidth >= 64:
1330+
# Already i64, no truncation needed
1331+
shape_stride_results = list(extracted_words[1:])
1332+
else:
1333+
shape_stride_results = [
1334+
bc.encode_TruncIOp(ctx.builder, ty_id, w, bc.IntegerOverflow.NONE)
1335+
for ty_id, w in zip(item_typeid_tuple[1:], extracted_words[1:], strict=True)
1336+
]
1337+
13291338
return (
13301339
# Cast the first word to data pointer
13311340
bc.encode_IntToPtrOp(ctx.builder, item_typeid_tuple[0], extracted_words[0]),
1332-
# Cast the remaining words to i32 shape/strides
1333-
*(bc.encode_TruncIOp(ctx.builder, ty, w, bc.IntegerOverflow.NONE)
1334-
for ty, w in zip(item_typeid_tuple[1:], extracted_words[1:], strict=True))
1341+
# Cast the remaining words to shape/stride types (i32 or i64)
1342+
*shape_stride_results
13351343
)
13361344

13371345

@@ -2260,6 +2268,7 @@ def maybe_const_int(v: Var):
22602268
array_ty.element_type,
22612269
shape=new_shape_ty,
22622270
strides=array_ty.strides,
2271+
index_dtype=array_ty.index_dtype,
22632272
)
22642273

22652274
array_val = array.get_aggregate()
@@ -2334,12 +2343,15 @@ class TileLoad(Operation, opcode="tile_load", memory_effect=MemoryEffect.LOAD):
23342343
@override
23352344
def generate_bytecode(self, ctx: BytecodeContext) -> tuple[bc.Value, bc.Value]:
23362345
tile_type: TileTy = self.result_vars[0].get_type()
2346+
view_ty = self.view.get_type()
2347+
keep_i64 = (isinstance(view_ty, PartitionViewTy)
2348+
and view_ty.array_ty.index_dtype.bitwidth > 32)
23372349
res, res_token = bc.encode_LoadViewTkoOp(
23382350
ctx.builder,
23392351
tile_type=typeid(ctx.type_table, tile_type),
23402352
result_token_type=ctx.type_table.Token,
23412353
view=ctx.get_value(self.view),
2342-
index=ctx.index_tuple(self.index),
2354+
index=ctx.index_tuple(self.index, keep_i64=keep_i64),
23432355
token=None if self.token is None else ctx.get_value(self.token),
23442356
memory_ordering_semantics=memory_order_to_bytecode[self.memory_order],
23452357
memory_scope=memory_scope_to_bytecode[self.memory_scope],
@@ -2359,6 +2371,11 @@ def _tile_load_impl_inner(array: Var, index_items: tuple[Var, ...], shape: Seque
23592371
allow_tma = require_optional_constant_bool(allow_tma)
23602372
_check_load_store_hints(latency, allow_tma)
23612373

2374+
# Promote indices to i64 for big arrays so that blockId * tileSize
2375+
# doesn't overflow i32 in the backend's address computation.
2376+
if array_ty.index_dtype.bitwidth > 32:
2377+
index_items = tuple(astype(idx, array_ty.index_dtype) for idx in index_items)
2378+
23622379
view = make_partition_view(array, broadcasted_shape, order, padding_mode)
23632380
res_ty = make_tile_ty(array_ty.dtype, broadcasted_shape)
23642381
result, _token = add_operation(TileLoad, (res_ty, TokenTy()),
@@ -2482,12 +2499,15 @@ class TileStore(Operation, opcode="tile_store", memory_effect=MemoryEffect.STORE
24822499

24832500
@override
24842501
def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
2502+
view_ty = self.view.get_type()
2503+
keep_i64 = (isinstance(view_ty, PartitionViewTy)
2504+
and view_ty.array_ty.index_dtype.bitwidth > 32)
24852505
return bc.encode_StoreViewTkoOp(
24862506
ctx.builder,
24872507
result_token_type=ctx.type_table.Token,
24882508
tile=ctx.get_value(self.tile),
24892509
view=ctx.get_value(self.view),
2490-
index=ctx.index_tuple(self.index),
2510+
index=ctx.index_tuple(self.index, keep_i64=keep_i64),
24912511
token=None if self.token is None else ctx.get_value(self.token),
24922512
memory_ordering_semantics=memory_order_to_bytecode[self.memory_order],
24932513
memory_scope=memory_scope_to_bytecode[self.memory_scope],
@@ -2517,6 +2537,11 @@ def _tile_store_impl_inner(array: Var, index_items: tuple[Var, ...], tile: Var,
25172537
allow_tma = require_optional_constant_bool(allow_tma)
25182538
_check_load_store_hints(latency, allow_tma)
25192539

2540+
# Promote indices to i64 for big arrays so that blockId * tileSize
2541+
# doesn't overflow i32 in the backend's address computation.
2542+
if array_ty.index_dtype.bitwidth > 32:
2543+
index_items = tuple(astype(idx, array_ty.index_dtype) for idx in index_items)
2544+
25202545
tile = reshape(tile, broadcasted_shape)
25212546
view = make_partition_view(array, broadcasted_shape, order, PaddingMode.UNDETERMINED)
25222547
[_token] = add_operation(TileStore, (TokenTy(),), view=view, index=index_items, tile=tile,

src/cuda/tile/_ir2bytecode.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,10 +362,13 @@ def constant_tuple(self, value, ty: Type) -> Tuple[bc.Value, ...]:
362362
for item_ty, item_val in zip(ty.value_types, value, strict=True)), ())
363363
return self.constant(value, ty),
364364

365-
def index_tuple(self, index: tuple[Var, ...]) -> Tuple[bc.Value, ...]:
365+
def index_tuple(self,
366+
index: tuple[Var, ...], *, keep_i64: bool = False) -> Tuple[bc.Value, ...]:
366367
i32_tile_ty = self.type_table.tile(self.type_table.I32, ())
367368
item_types = tuple(x.get_type() for x in index)
368369
index_values = tuple(self.get_value(x) for x in index)
370+
if keep_i64:
371+
return index_values
369372
return tuple(
370373
bc.encode_TruncIOp(self.builder, i32_tile_ty, v, bc.IntegerOverflow.NONE)
371374
if (t.dtype if isinstance(t, TileTy) else t).bitwidth > 32 else v

0 commit comments

Comments
 (0)