Skip to content

Commit 84daa7d

Browse files
committed
Update encodings and supported bytecode version to 13.3
Signed-off-by: Qiqi Xiao <qiqix@nvidia.com>
1 parent bc04009 commit 84daa7d

7 files changed

Lines changed: 35 additions & 17 deletions

File tree

samples/BatchMatMul.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def bmm(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype) -> torch.Tenso
8080
output = torch.empty((Batch, M, N), device=a.device, dtype=out_dtype)
8181

8282
# --- Determine Tile Shapes for Optimization (Fixed for float16 as per previous request) ---
83-
tm_val, tn_val, tk_val = 128, 256, 64 # Larger tiles for Tensor Core benefits
83+
tm_val, tn_val, tk_val = 128, 256, 128 # Larger tiles for Tensor Core benefits
8484

8585
# --- Grid calculation for standard 3D tiled kernel ---
8686
grid = (Batch, ceil(M / tm_val), ceil(N / tn_val))
@@ -103,8 +103,7 @@ def torch_batch_matmul_fp8(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
103103
A_row = A[i].contiguous()
104104
B_col = B[i].transpose(-2, -1).contiguous().transpose(-2, -1)
105105
C[i] = torch._scaled_mm(
106-
A_row, B_col, scale_a=inv_sa, scale_b=inv_sb, out_dtype=torch.float32,
107-
use_fast_accum=True
106+
A_row, B_col, scale_a=inv_sa, scale_b=inv_sb, out_dtype=torch.float32
108107
)
109108
return C
110109

samples/templates/BatchMatMul.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def bmm(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype) -> torch.Tenso
4141
output = torch.empty((Batch, M, N), device=a.device, dtype=out_dtype)
4242

4343
# --- Determine Tile Shapes for Optimization (Fixed for float16 as per previous request) ---
44-
tm_val, tn_val, tk_val = 128, 256, 64 # Larger tiles for Tensor Core benefits
44+
tm_val, tn_val, tk_val = 128, 256, 128 # Larger tiles for Tensor Core benefits
4545

4646
# --- Grid calculation for standard 3D tiled kernel ---
4747
grid = (Batch, ceil(M / tm_val), ceil(N / tn_val))
@@ -64,8 +64,7 @@ def torch_batch_matmul_fp8(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
6464
A_row = A[i].contiguous()
6565
B_col = B[i].transpose(-2, -1).contiguous().transpose(-2, -1)
6666
C[i] = torch._scaled_mm(
67-
A_row, B_col, scale_a=inv_sa, scale_b=inv_sb, out_dtype=torch.float32,
68-
use_fast_accum=True
67+
A_row, B_col, scale_a=inv_sa, scale_b=inv_sb, out_dtype=torch.float32
6968
)
7069
return C
7170

src/cuda/tile/_bytecode/encodings.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,18 @@ class MemoryScope(enum.Enum):
6363
SYS = b"\x02"
6464

6565

66+
class ProgramIDDim(enum.Enum):
67+
X = b"\x00"
68+
Y = b"\x01"
69+
Z = b"\x02"
70+
71+
72+
class PtrAttr(enum.Enum):
73+
NONE = b"\x00"
74+
UNICAST = b"\x01"
75+
MULTICAST = b"\x02"
76+
77+
6678
class RoundingMode(enum.Enum):
6779
NEAREST_EVEN = b"\x00"
6880
ZERO = b"\x01"
@@ -313,8 +325,8 @@ def encode_AtomicRedViewTkoOp( # since 13.3
313325
code_builder: CodeBuilder,
314326
result_token_type: TypeId, # since 13.3
315327
view: Value, # since 13.3
328+
index: Sequence[Value], # since 13.3
316329
value: Value, # since 13.3
317-
mask: Optional[Value], # since 13.3
318330
token: Optional[Value], # since 13.3
319331
memory_ordering_semantics: MemoryOrderingSemantics, # since 13.3
320332
memory_scope: MemoryScope, # since 13.3
@@ -323,19 +335,18 @@ def encode_AtomicRedViewTkoOp( # since 13.3
323335
_buf = code_builder.buf
324336
# Opcode
325337
encode_varint(117, _buf)
326-
# Result types
327-
encode_typeid(result_token_type, _buf)
338+
# Variadic result types
339+
encode_sized_typeid_seq((result_token_type,), _buf)
328340
# Flags
329-
encode_varint((mask is not None)
330-
| ((token is not None) << 1), _buf)
341+
encode_varint((token is not None), _buf)
331342
# Attributes
332343
code_builder.encode_opattr_enum(MemoryOrderingSemantics, memory_ordering_semantics)
333344
code_builder.encode_opattr_enum(MemoryScope, memory_scope)
334345
code_builder.encode_opattr_enum(AtomicRMWMode, mode)
335346
# Operands
336347
encode_operand(view, _buf)
348+
encode_sized_variadic_operands(index, _buf)
337349
encode_operand(value, _buf)
338-
encode_optional_operand(mask, _buf)
339350
encode_optional_operand(token, _buf)
340351
return code_builder.new_op()
341352

@@ -1242,12 +1253,18 @@ def encode_MmaFOp(
12421253
lhs: Value,
12431254
rhs: Value,
12441255
acc: Value,
1256+
fast_acc: bool, # since 13.3
12451257
) -> Value:
12461258
_buf = code_builder.buf
12471259
# Opcode
12481260
encode_varint(73, _buf)
12491261
# Result types
12501262
encode_typeid(result_type, _buf)
1263+
# Flags
1264+
_flag_bits = bool(fast_acc)
1265+
assert _flag_bits < 1 or code_builder.version >= BytecodeVersion.V_13_3
1266+
if code_builder.version >= BytecodeVersion.V_13_3:
1267+
encode_varint(_flag_bits, _buf)
12511268
# Operands
12521269
encode_operand(lhs, _buf)
12531270
encode_operand(rhs, _buf)
@@ -2024,6 +2041,8 @@ def encode_YieldOp(
20242041
'IntegerOverflow',
20252042
'MemoryOrderingSemantics',
20262043
'MemoryScope',
2044+
'ProgramIDDim',
2045+
'PtrAttr',
20272046
'RoundingMode',
20282047
'Signedness',
20292048
'SymbolVisibility',

src/cuda/tile/_compile.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,7 @@ def _find_compiler_bin() -> _CompilerBinary:
571571
_SUPPORTED_VERSIONS = [
572572
BytecodeVersion.V_13_1,
573573
BytecodeVersion.V_13_2,
574+
BytecodeVersion.V_13_3,
574575
]
575576

576577

src/cuda/tile/_ir/ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3217,8 +3217,9 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
32173217
return bc.encode_MmaIOp(ctx.builder, res_typeid, x_value, y_value,
32183218
acc_value, signedness_lhs, signedness_rhs)
32193219
else:
3220+
# TODO: consider exposing fast_acc to callers; it is hard-coded to False for now
32203221
return bc.encode_MmaFOp(ctx.builder, res_typeid, x_value, y_value,
3221-
acc_value)
3222+
acc_value, fast_acc=False)
32223223

32233224

32243225
@impl(ct.mma)

test/bench_matmul.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,7 @@ def torch_batch_matmul(bs, A, B, C):
181181
A_row = A[i].contiguous()
182182
B_col = B[i].transpose(-2, -1).contiguous().transpose(-2, -1)
183183
C[i] = torch._scaled_mm(
184-
A_row, B_col, scale_a=inv_sa, scale_b=inv_sb, out_dtype=torch.float32,
185-
use_fast_accum=True
184+
A_row, B_col, scale_a=inv_sa, scale_b=inv_sb, out_dtype=torch.float32
186185
)
187186

188187

test/test_mma.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def test_mma_fp8(tile_size, case):
125125
C = torch.ones((m, n), dtype=case.acc_dtype, device="cuda")
126126
scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
127127
try:
128-
ref = torch._scaled_mm(A, B.T, scale, scale, out_dtype=C.dtype, use_fast_accum=True) + C
128+
ref = torch._scaled_mm(A, B.T, scale, scale, out_dtype=C.dtype) + C
129129
except (RuntimeError, ValueError) as e:
130130
assert 'Multiplication of two Float8_e5m2 matrices is not supported' in str(e)
131131
ref = None
@@ -280,7 +280,7 @@ def test_matmul_fp8(tile_size, dtype):
280280
scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
281281
try:
282282
ref = torch._scaled_mm(A, B.T, scale, scale,
283-
out_dtype=torch.float16, use_fast_accum=True)
283+
out_dtype=torch.float16)
284284
except (RuntimeError, ValueError) as e:
285285
assert 'Multiplication of two Float8_e5m2 matrices is not supported' in str(e)
286286
ref = None

0 commit comments

Comments
 (0)