@@ -1326,12 +1326,20 @@ def generate_bytecode(self, ctx: BytecodeContext):
13261326 )
13271327
13281328 # Cast each of the i64 words to appropriate types
1329+ if list_ty .item_type .index_dtype .bitwidth >= 64 :
1330+ # Already i64, no truncation needed
1331+ shape_stride_results = list (extracted_words [1 :])
1332+ else :
1333+ shape_stride_results = [
1334+ bc .encode_TruncIOp (ctx .builder , ty_id , w , bc .IntegerOverflow .NONE )
1335+ for ty_id , w in zip (item_typeid_tuple [1 :], extracted_words [1 :], strict = True )
1336+ ]
1337+
13291338 return (
13301339 # Cast the first word to data pointer
13311340 bc .encode_IntToPtrOp (ctx .builder , item_typeid_tuple [0 ], extracted_words [0 ]),
1332- # Cast the remaining words to i32 shape/strides
1333- * (bc .encode_TruncIOp (ctx .builder , ty , w , bc .IntegerOverflow .NONE )
1334- for ty , w in zip (item_typeid_tuple [1 :], extracted_words [1 :], strict = True ))
1341+ # Cast the remaining words to shape/stride types (i32 or i64)
1342+ * shape_stride_results
13351343 )
13361344
13371345
@@ -2260,6 +2268,7 @@ def maybe_const_int(v: Var):
22602268 array_ty .element_type ,
22612269 shape = new_shape_ty ,
22622270 strides = array_ty .strides ,
2271+ index_dtype = array_ty .index_dtype ,
22632272 )
22642273
22652274 array_val = array .get_aggregate ()
@@ -2334,12 +2343,15 @@ class TileLoad(Operation, opcode="tile_load", memory_effect=MemoryEffect.LOAD):
23342343 @override
23352344 def generate_bytecode (self , ctx : BytecodeContext ) -> tuple [bc .Value , bc .Value ]:
23362345 tile_type : TileTy = self .result_vars [0 ].get_type ()
2346+ view_ty = self .view .get_type ()
2347+ keep_i64 = (isinstance (view_ty , PartitionViewTy )
2348+ and view_ty .array_ty .index_dtype .bitwidth > 32 )
23372349 res , res_token = bc .encode_LoadViewTkoOp (
23382350 ctx .builder ,
23392351 tile_type = typeid (ctx .type_table , tile_type ),
23402352 result_token_type = ctx .type_table .Token ,
23412353 view = ctx .get_value (self .view ),
2342- index = ctx .index_tuple (self .index ),
2354+ index = ctx .index_tuple (self .index , keep_i64 = keep_i64 ),
23432355 token = None if self .token is None else ctx .get_value (self .token ),
23442356 memory_ordering_semantics = memory_order_to_bytecode [self .memory_order ],
23452357 memory_scope = memory_scope_to_bytecode [self .memory_scope ],
@@ -2359,6 +2371,11 @@ def _tile_load_impl_inner(array: Var, index_items: tuple[Var, ...], shape: Seque
23592371 allow_tma = require_optional_constant_bool (allow_tma )
23602372 _check_load_store_hints (latency , allow_tma )
23612373
2374+ # For big arrays (index dtype wider than 32 bits), promote the indices to the
2375+ # array's index dtype so that blockId * tileSize doesn't overflow i32 in the
2375+ # backend's address computation.
2376+ if array_ty .index_dtype .bitwidth > 32 :
2377+ index_items = tuple (astype (idx , array_ty .index_dtype ) for idx in index_items )
2378+
23622379 view = make_partition_view (array , broadcasted_shape , order , padding_mode )
23632380 res_ty = make_tile_ty (array_ty .dtype , broadcasted_shape )
23642381 result , _token = add_operation (TileLoad , (res_ty , TokenTy ()),
@@ -2482,12 +2499,15 @@ class TileStore(Operation, opcode="tile_store", memory_effect=MemoryEffect.STORE
24822499
24832500 @override
24842501 def generate_bytecode (self , ctx : BytecodeContext ) -> bc .Value :
2502+ view_ty = self .view .get_type ()
2503+ keep_i64 = (isinstance (view_ty , PartitionViewTy )
2504+ and view_ty .array_ty .index_dtype .bitwidth > 32 )
24852505 return bc .encode_StoreViewTkoOp (
24862506 ctx .builder ,
24872507 result_token_type = ctx .type_table .Token ,
24882508 tile = ctx .get_value (self .tile ),
24892509 view = ctx .get_value (self .view ),
2490- index = ctx .index_tuple (self .index ),
2510+ index = ctx .index_tuple (self .index , keep_i64 = keep_i64 ),
24912511 token = None if self .token is None else ctx .get_value (self .token ),
24922512 memory_ordering_semantics = memory_order_to_bytecode [self .memory_order ],
24932513 memory_scope = memory_scope_to_bytecode [self .memory_scope ],
@@ -2517,6 +2537,11 @@ def _tile_store_impl_inner(array: Var, index_items: tuple[Var, ...], tile: Var,
25172537 allow_tma = require_optional_constant_bool (allow_tma )
25182538 _check_load_store_hints (latency , allow_tma )
25192539
2540+ # For big arrays (index dtype wider than 32 bits), promote the indices to the
2541+ # array's index dtype so that blockId * tileSize doesn't overflow i32 in the
2541+ # backend's address computation.
2542+ if array_ty .index_dtype .bitwidth > 32 :
2543+ index_items = tuple (astype (idx , array_ty .index_dtype ) for idx in index_items )
2544+
25202545 tile = reshape (tile , broadcasted_shape )
25212546 view = make_partition_view (array , broadcasted_shape , order , PaddingMode .UNDETERMINED )
25222547 [_token ] = add_operation (TileStore , (TokenTy (),), view = view , index = index_items , tile = tile ,
0 commit comments