Skip to content

Commit 286fb81

Browse files
committed
Fix fp8 matmul reference for sm90
Currently we need to pass `use_fast_accum=True` to `torch._scaled_mm` so that it matches the behavior of `ct.mma` and `ct.scaled_mma` with fp8 on sm90. Signed-off-by: Jay Gu <jagu@nvidia.com>
1 parent 7fb3407 commit 286fb81

File tree

4 files changed

+9
-5
lines changed

4 files changed

+9
-5
lines changed

samples/BatchMatMul.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ def torch_batch_matmul_fp8(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
103103
A_row = A[i].contiguous()
104104
B_col = B[i].transpose(-2, -1).contiguous().transpose(-2, -1)
105105
C[i] = torch._scaled_mm(
106-
A_row, B_col, scale_a=inv_sa, scale_b=inv_sb, out_dtype=torch.float32
106+
A_row, B_col, scale_a=inv_sa, scale_b=inv_sb, out_dtype=torch.float32,
107+
use_fast_accum=True
107108
)
108109
return C
109110

samples/templates/BatchMatMul.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def torch_batch_matmul_fp8(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
6464
A_row = A[i].contiguous()
6565
B_col = B[i].transpose(-2, -1).contiguous().transpose(-2, -1)
6666
C[i] = torch._scaled_mm(
67-
A_row, B_col, scale_a=inv_sa, scale_b=inv_sb, out_dtype=torch.float32
67+
A_row, B_col, scale_a=inv_sa, scale_b=inv_sb, out_dtype=torch.float32,
68+
use_fast_accum=True
6869
)
6970
return C
7071

test/bench_matmul.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ def torch_batch_matmul(bs, A, B, C):
181181
A_row = A[i].contiguous()
182182
B_col = B[i].transpose(-2, -1).contiguous().transpose(-2, -1)
183183
C[i] = torch._scaled_mm(
184-
A_row, B_col, scale_a=inv_sa, scale_b=inv_sb, out_dtype=torch.float32
184+
A_row, B_col, scale_a=inv_sa, scale_b=inv_sb, out_dtype=torch.float32,
185+
use_fast_accum=True
185186
)
186187

187188

test/test_mma.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def test_mma_fp8(tile_size, case):
125125
C = torch.ones((m, n), dtype=case.acc_dtype, device="cuda")
126126
scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
127127
try:
128-
ref = torch._scaled_mm(A, B.T, scale, scale, out_dtype=C.dtype) + C
128+
ref = torch._scaled_mm(A, B.T, scale, scale, out_dtype=C.dtype, use_fast_accum=True) + C
129129
except (RuntimeError, ValueError) as e:
130130
assert 'Multiplication of two Float8_e5m2 matrices is not supported' in str(e)
131131
ref = None
@@ -279,7 +279,8 @@ def test_matmul_fp8(tile_size, dtype):
279279
C = torch.zeros((m, n), dtype=dtype, device="cuda")
280280
scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
281281
try:
282-
ref = torch._scaled_mm(A, B.T, scale, scale, out_dtype=torch.float16).to(dtype)
282+
ref = torch._scaled_mm(A, B.T, scale, scale,
283+
out_dtype=torch.float16, use_fast_accum=True).to(dtype)
283284
except (RuntimeError, ValueError) as e:
284285
assert 'Multiplication of two Float8_e5m2 matrices is not supported' in str(e)
285286
ref = None

0 commit comments

Comments (0)