Skip to content

Commit fb2bdd0

Browse files
committed
pack_to_bytes and unpack_from_bytes
Signed-off-by: Boyan Li <boyanl@nvidia.com>
1 parent 3f342ff commit fb2bdd0

File tree

7 files changed

+341
-4
lines changed

7 files changed

+341
-4
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- New `ct.pack_to_bytes()` operation that flattens a tile and reinterprets its
5+
raw bytes as a 1D uint8 tile.
6+
- New `ct.unpack_from_bytes()` operation that reinterprets a 1D uint8 tile as a
7+
1D tile of the target dtype. Inverse of `ct.pack_to_bytes()`.

docs/source/operations.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ Shape & DType
5353
permute
5454
transpose
5555
astype
56+
bitcast
5657

5758

5859
Reduction

src/cuda/tile/_ir/ops.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3642,9 +3642,17 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
36423642
def bitcast(x: Var, dtype: DType) -> Var:
    """Reinterpret the bits of tile ``x`` as ``dtype`` without value conversion.

    Args:
        x: a tile-typed IR variable.
        dtype: target element type; must have the same bit width as x's dtype.

    Returns:
        A Var with the same shape and element type ``dtype``; ``x`` itself
        when the dtype is already ``dtype`` (no op is emitted in that case).

    Raises:
        TileTypeError: if either dtype is bool, or the bit widths differ.
    """
    tile_ty = require_tile_type(x)
    x_dtype = tile_ty.dtype
    # bool is rejected in both directions (checked before the bitwidth
    # comparison so bool->bool also raises, as the tests expect).
    if x_dtype == datatype.bool_ or dtype == datatype.bool_:
        raise TileTypeError(f"Cannot bitcast from {x_dtype} to {dtype}: "
                            f"bitcast to or from bool is not supported")

    if x_dtype.bitwidth != dtype.bitwidth:
        raise TileTypeError(f"Cannot bitcast from {x_dtype} to {dtype}: "
                            f"bit width is different ({x_dtype.bitwidth} vs. {dtype.bitwidth})")

    # Identity cast: skip emitting a TileBitCast operation entirely.
    if x_dtype == dtype:
        return x

    res_ty = make_tile_ty(dtype, tile_ty.shape_value)
    return add_operation(TileBitCast, res_ty, x=x)
36503658

@@ -3655,6 +3663,90 @@ def bitcast_impl(x: Var, dtype: Var) -> Var:
36553663
return bitcast(x, dtype_val)
36563664

36573665

3666+
@dataclass(eq=False)
class TilePack(Operation, opcode="tile_pack"):
    """IR operation reinterpreting a 1D tile as its raw bytes (uint8 tile)."""

    # Input tile; shape/dtype preconditions are enforced by pack() before
    # this operation is constructed.
    x: Var = operand()

    @override
    def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
        # Encode as a PackOp carrying the result type id and the input value.
        res_type_id = ctx.typeid_of(self.result_var)
        x_value = ctx.get_value(self.x)
        return bc.encode_PackOp(ctx.builder, res_type_id, x_value)
3675+
3676+
3677+
def pack(x: Var) -> Var:
    """Lower a 1D non-byte-sized tile to a 1D uint8 tile of its raw bytes.

    Raises:
        TileTypeError: if the tile's total bit count is not byte-aligned.
    """
    tile_ty = require_tile_type(x)
    # Internal invariants guaranteed by pack_to_bytes_impl: already
    # flattened, and 8-bit dtypes are handled via bitcast instead.
    assert tile_ty.ndim == 1
    assert tile_ty.dtype.bitwidth != 8

    n_elems = tile_ty.shape_value[0]
    n_bytes, leftover_bits = divmod(n_elems * tile_ty.dtype.bitwidth, 8)
    if leftover_bits != 0:
        raise TileTypeError(f"Cannot pack tile {tile_ty}: "
                            f"total bits ({n_elems} * {tile_ty.dtype.bitwidth}) "
                            f"not divisible by 8")
    return add_operation(TilePack, make_tile_ty(datatype.uint8, (n_bytes,)), x=x)
3689+
3690+
3691+
@impl(ct.pack_to_bytes, min_version=BytecodeVersion.V_13_3)
def pack_to_bytes_impl(x: Var):
    """Implementation of ct.pack_to_bytes: flatten, then reinterpret as uint8.

    Raises:
        TileTypeError: for bool input tiles (unsupported).
    """
    tile_ty = require_tile_type(x)
    x_dtype = tile_ty.dtype
    # Validate the dtype *before* emitting the flattening reshape so an
    # unsupported input does not leave a dead reshape op in the IR.
    if x_dtype == datatype.bool_:
        raise TileTypeError(f"pack_to_bytes from a {x_dtype} tile is not supported")

    x = reshape(x, (-1,))
    # Byte-sized element types need only a bitcast (a no-op for uint8 itself,
    # thanks to bitcast's identity shortcut).
    if x_dtype.bitwidth == 8:
        return bitcast(x, datatype.uint8)
    return pack(x)
3702+
3703+
3704+
@dataclass(eq=False)
class TileUnpack(Operation, opcode="tile_unpack"):
    """IR operation reinterpreting a 1D uint8 byte tile as another dtype."""

    # Input byte tile; preconditions are enforced by unpack() before this
    # operation is constructed.
    x: Var = operand()

    @override
    def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
        # Encode as an UnpackOp carrying the result type id and the input value.
        res_type_id = ctx.typeid_of(self.result_var)
        x_value = ctx.get_value(self.x)
        return bc.encode_UnpackOp(ctx.builder, res_type_id, x_value)
3713+
3714+
3715+
def unpack(x: Var, dtype: DType) -> Var:
    """Lower a 1D uint8 byte tile to a 1D tile of non-byte-sized ``dtype``.

    Raises:
        TileTypeError: if the byte count does not divide into whole elements.
    """
    tile_ty = require_tile_type(x)
    # Internal invariants guaranteed by unpack_from_bytes_impl: 1D uint8
    # input, and 8-bit targets are handled via bitcast instead.
    assert tile_ty.ndim == 1
    assert tile_ty.dtype == datatype.uint8
    assert dtype.bitwidth != 8

    n_bytes = tile_ty.shape_value[0]
    n_elems, leftover_bits = divmod(n_bytes * 8, dtype.bitwidth)
    if leftover_bits != 0:
        raise TileTypeError(
            f"Cannot unpack tile {tile_ty} to {dtype}: "
            f"total bits ({n_bytes} * 8) not divisible by {dtype.bitwidth}")
    return add_operation(TileUnpack, make_tile_ty(dtype, (n_elems,)), x=x)
3728+
3729+
3730+
@impl(ct.unpack_from_bytes, min_version=BytecodeVersion.V_13_3)
def unpack_from_bytes_impl(x: Var, dtype: Var):
    """Implementation of ct.unpack_from_bytes: reinterpret uint8 bytes as dtype.

    Raises:
        TileTypeError: for non-1D input, non-uint8 input, or bool target.
    """
    tile_ty = require_tile_type(x)
    src_dtype = tile_ty.dtype
    dst_dtype = require_dtype_spec(dtype)

    # Validation order is part of the observable behavior: rank first, then
    # input dtype, then target dtype.
    if tile_ty.ndim != 1:
        raise TileTypeError(
            f"unpack_from_bytes requires a 1D tile, "
            f"got {tile_ty.ndim}D tile with shape {tile_ty.shape_value}")
    if src_dtype != datatype.uint8:
        raise TileTypeError(
            f"unpack_from_bytes requires uint8 tile, got {src_dtype} tile")
    if dst_dtype == datatype.bool_:
        raise TileTypeError(f"unpack_from_bytes to a {dst_dtype} tile is not supported")

    # Byte-sized targets need only a bitcast (identity for uint8 itself).
    if dst_dtype.bitwidth == 8:
        return bitcast(x, dst_dtype)
    return unpack(x, dst_dtype)
3748+
3749+
36583750
@dataclass(eq=False)
36593751
class TileArange(Operation, opcode="tile_arange"):
36603752
@override

src/cuda/tile/_stub.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,55 @@ def bitcast(x, /, dtype) -> Tile:
12321232
"""
12331233

12341234

1235+
@function
def pack_to_bytes(x, /) -> Tile:
    """Flattens a tile and reinterprets its raw bytes as uint8 elements.

    The total number of bits of the input tile must be divisible by 8.

    Args:
        x (Tile): input tile.

    Returns:
        Tile: a 1D uint8 tile with ``total_elements * element_bit_width // 8``
        elements.

    Examples:

        >>> tx = ct.full((2, 4), 0, dtype=ct.int32)
        >>> ty = ct.pack_to_bytes(tx)
        >>> ty.dtype
        uint8
        >>> ty.shape
        (32,)
    """
1256+
1257+
1258+
@function
def unpack_from_bytes(x, /, dtype) -> Tile:
    """Reinterprets a 1D uint8 byte tile as a 1D tile of the target data type.

    The inverse of :py:func:`pack_to_bytes`. The input must be a 1D tile of
    dtype uint8, and its total number of bits must be divisible by the
    target data type's bit width.

    Args:
        x (Tile): a 1D tile of dtype uint8.
        dtype (DType): target data type.

    Returns:
        Tile: a 1D tile of ``dtype`` with ``num_bytes * 8 // element_bit_width``
        elements.

    Examples:

        >>> tx = ct.full((16,), 0, dtype=ct.uint8)
        >>> ty = ct.unpack_from_bytes(tx, ct.float32)
        >>> ty.dtype
        float32
        >>> ty.shape
        (4,)
    """
1282+
1283+
12351284
def _math_op_extra_block(f, indent):
12361285
base = inspect.unwrap(f)
12371286
sig = inspect.signature(base)

test/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def shape_size_id(shape):
# Dtype families shared across the test suite for parametrization.
float_dtypes = [torch.float16, torch.bfloat16, torch.float32]
int_dtypes = [torch.int32, torch.int64, torch.int16, torch.int8]
bool_dtypes = [torch.bool]
# uint8 included so byte-level ops (pack/unpack, bitcast) are covered.
uint_dtypes = [torch.uint8, torch.uint32, torch.uint64]
arithmetic_dtypes = int_dtypes + uint_dtypes + float_dtypes + bool_dtypes
111111

112112

test/test_cast.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,17 +95,24 @@ def test_cast_tf32(dtype):
9595
(torch.int32, torch.int64),
9696
(torch.int64, torch.float32),
9797
(torch.float16, torch.int32),
98+
# failing pairs with bool
99+
(torch.bool, torch.int8),
100+
(torch.uint8, torch.bool),
101+
(torch.bool, torch.bool),
98102
])
99103
def test_array_bitcast(shape, tile, dtype_x, dtype_y):
100104
# avoid inputs that could produce nans of infs to not break assert
101-
if dtype_x in (torch.int32, torch.int64):
105+
if dtype_x == torch.bool:
106+
x = torch.randint(0, 2, shape, dtype=dtype_x, device='cuda')
107+
elif dtype_x in (torch.int32, torch.int64, torch.int8, torch.uint8):
102108
x = torch.randint(0, 100, shape, dtype=dtype_x, device='cuda')
103109
else:
104110
x = torch.randn(shape, dtype=dtype_x, device='cuda')
105111
ref = x.view(dtype=dtype_y)
106112
y = torch.zeros_like(ref)
107113
grid = (ceil(shape[0] / tile), 1, 1)
108-
if dtype_x.itemsize != dtype_y.itemsize:
114+
if (dtype_x == torch.bool or dtype_y == torch.bool
115+
or dtype_x.itemsize != dtype_y.itemsize):
109116
with pytest.raises(TileTypeError):
110117
ct.launch(torch.cuda.current_stream(), grid, array_bitcast, (x, y, tile))
111118

test/test_pack_unpack.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import pytest
6+
import torch
7+
from torch.testing import make_tensor
8+
9+
import cuda.tile as ct
10+
from cuda.tile._bytecode.version import BytecodeVersion
11+
from util import assert_equal
12+
from cuda.tile._exception import TileTypeError
13+
from conftest import float_dtypes, int_dtypes, requires_tileiras, uint_dtypes, dtype_id
14+
15+
# TODO: remove when feature is out of development only
# Development-only: the ops are not yet exported on the ct namespace, so the
# stubs are attached here for the kernels below to reference.
from cuda.tile._stub import pack_to_bytes, unpack_from_bytes
ct.pack_to_bytes = pack_to_bytes
ct.unpack_from_bytes = unpack_from_bytes

# NOTE(review): presumably gates the whole module on bytecode >= 13.3
# support — confirm requires_tileiras semantics in conftest.
pytestmark = requires_tileiras(BytecodeVersion.V_13_3)

# All dtypes exercised here, plus float64 on top of the shared groups.
test_dtypes = float_dtypes + int_dtypes + uint_dtypes + [torch.float64]
23+
24+
25+
@ct.kernel
def pack_unpack_1d(x, y, TILE: ct.Constant[int]):
    # Round-trip kernel: load TILE elements from x, pack to raw bytes,
    # reinterpret those bytes as y's dtype, and store into y.
    tx = ct.load(x, index=(0,), shape=(TILE,))
    packed = ct.pack_to_bytes(tx)
    ty = ct.unpack_from_bytes(packed, y.dtype)
    ct.store(y, index=(0,), tile=ty)
31+
32+
33+
@pytest.mark.parametrize("dtype", test_dtypes, ids=dtype_id)
def test_pack_to_bytes(dtype):
    """pack_to_bytes must match torch's byte-level view of the same data."""
    @ct.kernel
    def kernel(x, y, TILE: ct.Constant[int]):
        tx = ct.load(x, index=(0,), shape=(TILE,))
        ty = ct.pack_to_bytes(tx)
        ct.store(y, index=(0,), tile=ty)

    tile = 128
    x = make_tensor((tile,), dtype=dtype, device='cuda')
    # Output holds one uint8 per byte of the input tile.
    nbytes = tile * x.element_size()
    y = torch.zeros(nbytes, dtype=torch.uint8, device='cuda')
    ct.launch(torch.cuda.current_stream(), (1,), kernel, (x, y, tile))
    ref = x.view(torch.uint8)
    assert_equal(y, ref)


@pytest.mark.parametrize("dtype", test_dtypes, ids=dtype_id)
def test_unpack_from_bytes(dtype):
    """unpack_from_bytes must invert torch's byte-level view of the data."""
    @ct.kernel
    def kernel(x, y, TILE: ct.Constant[int]):
        tx = ct.load(x, index=(0,), shape=(TILE,))
        ty = ct.unpack_from_bytes(tx, y.dtype)
        ct.store(y, index=(0,), tile=ty)

    # Build the byte input from a reference tensor so the expected result
    # is the reference itself.
    ref = make_tensor((32,), dtype=dtype, device='cuda')
    x = ref.view(torch.uint8)
    y = torch.zeros_like(ref)
    tile = x.shape[0]
    ct.launch(torch.cuda.current_stream(), (1,), kernel, (x, y, tile))
    assert_equal(y, ref)
64+
65+
66+
@pytest.mark.parametrize("dtype", test_dtypes, ids=dtype_id)
def test_pack_unpack_roundtrip(dtype):
    """pack_to_bytes followed by unpack_from_bytes reproduces the input."""
    n = 128
    src = make_tensor((n,), dtype=dtype, device='cuda')
    dst = torch.zeros_like(src)
    ct.launch(torch.cuda.current_stream(), (1,), pack_unpack_1d, (src, dst, n))
    assert_equal(dst, src)
73+
74+
75+
@pytest.mark.parametrize("dtype", test_dtypes, ids=dtype_id)
def test_pack_unpack_roundtrip_0d(dtype):
    """Round-trip of a single-element (0D) tile via gather/scatter."""
    @ct.kernel
    def kernel(x, y):
        tx = ct.gather(x, ())
        packed = ct.pack_to_bytes(tx)
        ty = ct.unpack_from_bytes(packed, x.dtype)
        # unpack always yields a 1D tile; restore the scalar shape to store.
        ty = ty.reshape(())
        ct.scatter(y, (), ty)

    x = make_tensor((), dtype=dtype, device='cuda')
    y = torch.zeros_like(x)
    ct.launch(torch.cuda.current_stream(), (1,), kernel, (x, y))
    assert_equal(y, x)


@pytest.mark.parametrize("dtype", test_dtypes, ids=dtype_id)
def test_pack_unpack_roundtrip_2d(dtype):
    """Round-trip of 2D tiles: pack flattens, so reshape back after unpack."""
    @ct.kernel
    def kernel(x, y, TILE_M: ct.Constant[int], TILE_N: ct.Constant[int]):
        bidm = ct.bid(0)
        bidn = ct.bid(1)
        tx = ct.load(x, index=(bidm, bidn), shape=(TILE_M, TILE_N))
        packed = ct.pack_to_bytes(tx)
        ty = ct.unpack_from_bytes(packed, x.dtype)
        ty = ct.reshape(ty, (TILE_M, TILE_N))
        ct.store(y, index=(bidm, bidn), tile=ty)

    shape = (64, 128)
    tiles = (32, 64)
    x = make_tensor(shape, dtype=dtype, device='cuda')
    y = torch.zeros_like(x)
    # One block per tile in each dimension.
    grid = (ct.cdiv(shape[0], tiles[0]), ct.cdiv(shape[1], tiles[1]))
    ct.launch(torch.cuda.current_stream(), grid,
              kernel, (x, y, tiles[0], tiles[1]))
    assert_equal(y, x)
111+
112+
113+
@pytest.mark.parametrize("dtype_x", test_dtypes, ids=dtype_id)
@pytest.mark.parametrize("dtype_y", test_dtypes, ids=dtype_id)
def test_cross_type_pack_unpack(dtype_x, dtype_y):
    """Pack from dtype_x, unpack into dtype_y; compare with torch's view."""
    tile = 128
    x = make_tensor((tile,), dtype=dtype_x, device='cuda')
    # 128 elements of any tested dtype give a byte count divisible by every
    # tested element size, so the view below is always valid.
    ref = x.view(torch.uint8).view(dtype_y)
    y = torch.zeros_like(ref)
    ct.launch(torch.cuda.current_stream(), (1,), pack_unpack_1d, (x, y, tile))
    assert_equal(y, ref)
122+
123+
124+
def test_unpack_from_bytes_not_divisible():
    """2 bytes (16 bits) cannot be unpacked into 32-bit elements."""
    @ct.kernel
    def kernel(x, y):
        tx = ct.load(x, index=(0,), shape=(2,))
        ct.unpack_from_bytes(tx, y.dtype)

    x = torch.ones(2, dtype=torch.uint8, device='cuda')
    y = torch.zeros(1, dtype=torch.int32, device='cuda')
    with pytest.raises(TileTypeError, match="not divisible by 32"):
        ct.launch(torch.cuda.current_stream(), (1,), kernel, (x, y))


def test_unpack_from_bytes_wrong_input_dtype():
    """The input tile must be uint8; int32 input is rejected."""
    @ct.kernel
    def kernel(x, y):
        tx = ct.load(x, index=(0,), shape=(4,))
        ct.unpack_from_bytes(tx, y.dtype)

    x = torch.ones(4, dtype=torch.int32, device='cuda')
    y = torch.zeros(4, dtype=torch.int32, device='cuda')
    with pytest.raises(TileTypeError, match="unpack_from_bytes requires uint8 tile"):
        ct.launch(torch.cuda.current_stream(), (1,), kernel, (x, y))


def test_unpack_from_bytes_not_1d():
    """The input tile must be 1D; a 2D tile is rejected."""
    @ct.kernel
    def kernel(x, y):
        tx = ct.load(x, index=(0, 0), shape=(4, 4))
        ct.unpack_from_bytes(tx, y.dtype)

    x = torch.ones((4, 4), dtype=torch.uint8, device='cuda')
    y = torch.zeros(4, dtype=torch.int32, device='cuda')
    with pytest.raises(TileTypeError, match="unpack_from_bytes requires a 1D tile"):
        ct.launch(torch.cuda.current_stream(), (1,), kernel, (x, y))
158+
159+
160+
def test_pack_to_bytes_bool():
    """Packing a bool tile is unsupported and must raise."""
    @ct.kernel
    def kernel(x, y, TILE: ct.Constant[int]):
        tx = ct.load(x, index=(0,), shape=(TILE,))
        ct.pack_to_bytes(tx)

    x = torch.ones(4, dtype=torch.bool, device='cuda')
    y = torch.zeros(4, dtype=torch.uint8, device='cuda')
    with pytest.raises(TileTypeError, match="pack_to_bytes from a bool_ tile"):
        ct.launch(torch.cuda.current_stream(), (1,), kernel, (x, y, 4))


def test_unpack_from_bytes_bool():
    """Unpacking into a bool tile is unsupported and must raise."""
    @ct.kernel
    def kernel(x, y):
        tx = ct.load(x, index=(0,), shape=(4,))
        ct.unpack_from_bytes(tx, y.dtype)

    x = torch.ones(4, dtype=torch.uint8, device='cuda')
    y = torch.zeros(4, dtype=torch.bool, device='cuda')
    with pytest.raises(TileTypeError, match="unpack_from_bytes to a bool_ tile"):
        ct.launch(torch.cuda.current_stream(), (1,), kernel, (x, y))

0 commit comments

Comments
 (0)