Skip to content

Commit 2df0fc2

Browse files
committed
Raise UnsupportedFeatureError for FP8 on sm80 family
- Relax tf32 mma test tolerance for the sm80 family - Fix rmsnorm kernel to add zero padding when loading out of bounds - Reduce tile size on the sm80 family for the persistent rmsnorm benchmark Signed-off-by: Jay Gu <jagu@nvidia.com>
1 parent 7bad878 commit 2df0fc2

16 files changed

Lines changed: 145 additions & 55 deletions
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
<!--- SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -->
2+
<!--- SPDX-License-Identifier: Apache-2.0 -->
3+
4+
- Compiling an FP8 operation for the SM80 family will raise `TileUnsupportedFeatureError`
5+
- Add `TileUnsupportedFeatureError` to the public API

docs/source/debugging.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Exception Types
1313
.. autoclass:: TileSyntaxError()
1414
.. autoclass:: TileTypeError()
1515
.. autoclass:: TileValueError()
16+
.. autoclass:: TileUnsupportedFeatureError()
1617
.. autoclass:: TileCompilerExecutionError()
1718
.. autoclass:: TileCompilerTimeoutError()
1819

samples/BatchMatMul.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -143,23 +143,26 @@ def torch_batch_matmul_fp8(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
143143

144144
# --- Test Case 2: Standard BMM (float8_e4m3fn) ---
145145
print("\n--- Test 2: Standard BMM (float8_e4m3fn) ---")
146-
A_fp8 = torch.randn(
147-
BATCH_DIM, M_DIM, K_DIM, dtype=torch.float32, device='cuda'
148-
).to(torch.float8_e4m3fn)
149-
B_fp8 = torch.randn(
150-
BATCH_DIM, K_DIM, N_DIM, dtype=torch.float32, device='cuda'
151-
).to(torch.float8_e4m3fn)
152-
print(f"Input A shape: {A_fp8.shape}, dtype: {A_fp8.dtype}")
153-
print(f"Input B shape: {B_fp8.shape}, dtype: {B_fp8.dtype}")
154-
155-
C_bmm_cutile_fp32 = bmm(A_fp8, B_fp8, torch.float32)
156-
print(f"""cuTile Standard BMM Output C
157-
shape:{C_bmm_cutile_fp32.shape},
158-
dtype: {C_bmm_cutile_fp32.dtype}""")
159-
if args.correctness_check:
160-
torch.testing.assert_close(C_bmm_cutile_fp32, torch_batch_matmul_fp8(A_fp8, B_fp8))
161-
print("Correctness check passed")
146+
if torch.cuda.get_device_capability()[0] == 8:
147+
print("skip: Ampere does not support float8")
162148
else:
163-
print("Correctness check disabled")
149+
A_fp8 = torch.randn(
150+
BATCH_DIM, M_DIM, K_DIM, dtype=torch.float32, device='cuda'
151+
).to(torch.float8_e4m3fn)
152+
B_fp8 = torch.randn(
153+
BATCH_DIM, K_DIM, N_DIM, dtype=torch.float32, device='cuda'
154+
).to(torch.float8_e4m3fn)
155+
print(f"Input A shape: {A_fp8.shape}, dtype: {A_fp8.dtype}")
156+
print(f"Input B shape: {B_fp8.shape}, dtype: {B_fp8.dtype}")
157+
158+
C_bmm_cutile_fp32 = bmm(A_fp8, B_fp8, torch.float32)
159+
print(f"""cuTile Standard BMM Output C
160+
shape:{C_bmm_cutile_fp32.shape},
161+
dtype: {C_bmm_cutile_fp32.dtype}""")
162+
if args.correctness_check:
163+
torch.testing.assert_close(C_bmm_cutile_fp32, torch_batch_matmul_fp8(A_fp8, B_fp8))
164+
print("Correctness check passed")
165+
else:
166+
print("Correctness check disabled")
164167

165168
print("\n--- cuTile Batched Matrix Multiplication (Standard Tiled) examples complete ---")

samples/MatMul.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,11 @@ def cutile_matmul(A: torch.Tensor, B: torch.Tensor, persistent: bool = False) ->
294294
print(f"Input A shape: {A_fp32.shape}, dtype: {A_fp32.dtype}")
295295
print(f"Input B shape: {B_fp32.shape}, dtype: {B_fp32.dtype}")
296296

297-
atol, rtol = 1e-4, 1e-3
297+
if torch.cuda.get_device_capability()[0] <= 8:
298+
# Ampere tfloat32 numerics is loose
299+
atol, rtol = 5e-3, 5e-3
300+
else:
301+
atol, rtol = 1e-4, 1e-3
298302

299303
# Perform matrix multiplication using the cuTile wrapper function.
300304
C_fp32_cutile = cutile_matmul(A_fp32, B_fp32)

samples/templates/BatchMatMul.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -104,23 +104,26 @@ def torch_batch_matmul_fp8(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
104104

105105
# --- Test Case 2: Standard BMM (float8_e4m3fn) ---
106106
print("\n--- Test 2: Standard BMM (float8_e4m3fn) ---")
107-
A_fp8 = torch.randn(
108-
BATCH_DIM, M_DIM, K_DIM, dtype=torch.float32, device='cuda'
109-
).to(torch.float8_e4m3fn)
110-
B_fp8 = torch.randn(
111-
BATCH_DIM, K_DIM, N_DIM, dtype=torch.float32, device='cuda'
112-
).to(torch.float8_e4m3fn)
113-
print(f"Input A shape: {A_fp8.shape}, dtype: {A_fp8.dtype}")
114-
print(f"Input B shape: {B_fp8.shape}, dtype: {B_fp8.dtype}")
115-
116-
C_bmm_cutile_fp32 = bmm(A_fp8, B_fp8, torch.float32)
117-
print(f"""cuTile Standard BMM Output C
118-
shape:{C_bmm_cutile_fp32.shape},
119-
dtype: {C_bmm_cutile_fp32.dtype}""")
120-
if args.correctness_check:
121-
torch.testing.assert_close(C_bmm_cutile_fp32, torch_batch_matmul_fp8(A_fp8, B_fp8))
122-
print("Correctness check passed")
107+
if torch.cuda.get_device_capability()[0] == 8:
108+
print("skip: Ampere does not support float8")
123109
else:
124-
print("Correctness check disabled")
110+
A_fp8 = torch.randn(
111+
BATCH_DIM, M_DIM, K_DIM, dtype=torch.float32, device='cuda'
112+
).to(torch.float8_e4m3fn)
113+
B_fp8 = torch.randn(
114+
BATCH_DIM, K_DIM, N_DIM, dtype=torch.float32, device='cuda'
115+
).to(torch.float8_e4m3fn)
116+
print(f"Input A shape: {A_fp8.shape}, dtype: {A_fp8.dtype}")
117+
print(f"Input B shape: {B_fp8.shape}, dtype: {B_fp8.dtype}")
118+
119+
C_bmm_cutile_fp32 = bmm(A_fp8, B_fp8, torch.float32)
120+
print(f"""cuTile Standard BMM Output C
121+
shape:{C_bmm_cutile_fp32.shape},
122+
dtype: {C_bmm_cutile_fp32.dtype}""")
123+
if args.correctness_check:
124+
torch.testing.assert_close(C_bmm_cutile_fp32, torch_batch_matmul_fp8(A_fp8, B_fp8))
125+
print("Correctness check passed")
126+
else:
127+
print("Correctness check disabled")
125128

126129
print("\n--- cuTile Batched Matrix Multiplication (Standard Tiled) examples complete ---")

samples/templates/MatMul.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,11 @@ def cutile_matmul(A: torch.Tensor, B: torch.Tensor, persistent: bool = False) ->
127127
print(f"Input A shape: {A_fp32.shape}, dtype: {A_fp32.dtype}")
128128
print(f"Input B shape: {B_fp32.shape}, dtype: {B_fp32.dtype}")
129129

130-
atol, rtol = 1e-4, 1e-3
130+
if torch.cuda.get_device_capability()[0] <= 8:
131+
# Ampere tfloat32 numerics is loose
132+
atol, rtol = 5e-3, 5e-3
133+
else:
134+
atol, rtol = 1e-4, 1e-3
131135

132136
# Perform matrix multiplication using the cuTile wrapper function.
133137
C_fp32_cutile = cutile_matmul(A_fp32, B_fp32)

src/cuda/tile/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
TileRecursionError,
4848
TileSyntaxError,
4949
TileTypeError,
50+
TileUnsupportedFeatureError,
5051
TileValueError,
5152
)
5253

@@ -181,6 +182,7 @@
181182
"TileRecursionError",
182183
"TileSyntaxError",
183184
"TileTypeError",
185+
"TileUnsupportedFeatureError",
184186
"TileValueError",
185187

186188
"Array",

src/cuda/tile/_compile.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
)
4747

4848
from cuda.tile._passes.alias_analysis import alias_analysis_pass
49+
from cuda.tile._passes.check_ampere_fp8 import check_ampere_fp8
4950
from cuda.tile._passes.dce import dead_code_elimination_pass
5051
from cuda.tile._passes.token_order import token_order_pass
5152
from cuda.tile._ir2bytecode import generate_bytecode_for_kernel
@@ -195,6 +196,7 @@ def compile_tile(pyfunc,
195196
print(f'\n{code}', file=sys.stderr)
196197

197198
sm_arch = get_sm_arch()
199+
check_ampere_fp8(func_ir.body, sm_arch)
198200

199201
bytecode_generator = functools.partial(generate_bytecode_for_kernel,
200202
func_ir, compiler_options, sm_arch)

src/cuda/tile/_exception.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,12 @@ class TileValueError(TileError):
145145
pass
146146

147147

148+
class TileUnsupportedFeatureError(TileError):
149+
"""Exception when a feature is not supported by the underlying compiler or
150+
the GPU architecture."""
151+
pass
152+
153+
148154
class TileInternalError(TileError):
149155
pass
150156

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from cuda.tile._ir.ir import Block
6+
from cuda.tile._ir.type import TileTy, ArrayTy
7+
from cuda.tile._datatype import float8_e4m3fn, float8_e5m2, DType
8+
from cuda.tile._exception import TileUnsupportedFeatureError
9+
10+
FLOAT8_DTYPES = (float8_e4m3fn, float8_e5m2)
11+
12+
13+
def check_ampere_fp8(root_block: Block, sm_arch: str) -> None:
14+
# Technically sm_89 (Ada Lovelace) supports FP8, but tileiras doesn't have support for it yet.
15+
if not sm_arch.startswith("sm_8"):
16+
return
17+
18+
for op in root_block.traverse():
19+
for var in op.all_inputs():
20+
ty = var.try_get_type()
21+
dtype = None
22+
if isinstance(ty, (TileTy, ArrayTy)):
23+
dtype = ty.dtype
24+
elif isinstance(ty, DType):
25+
dtype = ty
26+
if dtype in FLOAT8_DTYPES:
27+
raise TileUnsupportedFeatureError(
28+
"float8 dtype is not supported on Ampere or Ada Lovelace (sm_8*) architecture",
29+
loc=op.loc
30+
)

0 commit comments

Comments
 (0)