Skip to content

Commit caee711

Browse files
committed
[Bench] Add option for max_swizzle_size
1 parent 07f67b6 commit caee711

1 file changed

Lines changed: 10 additions & 2 deletions

File tree

benchmarks/benchmark_gemm.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def parse_arguments() -> argparse.Namespace:
5050
parser.add_argument("--varlen_k", action="store_true", help="Variable length K dimension")
5151
parser.add_argument("--gather_A", action="store_true", help="Gather A")
5252
parser.add_argument("--use_tma_gather", action="store_true", help="Use TMA gather4 for A")
53+
parser.add_argument("--max_swizzle_size", type=int, default=8, help="Max swizzle size")
5354
parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking")
5455

5556
args = parser.parse_args()
@@ -109,7 +110,12 @@ def run(args):
109110
total_k = k * l
110111
cu_seqlens_k = torch.arange(0, l + 1, dtype=torch.int32, device=device) * k
111112
# m-major A, n-major B for varlen_k
112-
A = torch.randn(total_k, m, dtype=torch.bfloat16, device=device).T
113+
if gather_A:
114+
larger_k = total_k * 2
115+
A = torch.randn(larger_k, m, dtype=torch.bfloat16, device=device).T
116+
A_idx = torch.randperm(larger_k, dtype=torch.int32, device=device)[:total_k]
117+
else:
118+
A = torch.randn(total_k, m, dtype=torch.bfloat16, device=device).T
113119
B = torch.randn(total_k, n, dtype=torch.bfloat16, device=device).T
114120
D = torch.empty(l, m, n, dtype=torch.bfloat16, device=device)
115121
else:
@@ -127,6 +133,7 @@ def fn():
127133
pingpong=args.pingpong,
128134
persistent=persistent,
129135
is_dynamic_persistent=args.dynamic_persistent,
136+
max_swizzle_size=args.max_swizzle_size,
130137
cu_seqlens_m=cu_seqlens_m,
131138
cu_seqlens_k=cu_seqlens_k,
132139
A_idx=A_idx,
@@ -146,7 +153,8 @@ def fn():
146153
])
147154
elif varlen_k:
148155
ref = torch.stack([
149-
A[:, cu_seqlens_k[i]:cu_seqlens_k[i+1]] @
156+
(A[:, A_idx[cu_seqlens_k[i]:cu_seqlens_k[i+1]]] if gather_A
157+
else A[:, cu_seqlens_k[i]:cu_seqlens_k[i+1]]) @
150158
B[:, cu_seqlens_k[i]:cu_seqlens_k[i+1]].T
151159
for i in range(l)
152160
])

0 commit comments

Comments
 (0)