Skip to content

Commit cc1bb09

Browse files
shijiashuai authored and qwencoder committed
fix: format all Python files with ruff
Format remaining files that CI checks: - python/profiler.py - tests/conftest.py - tests/test_interface.py - tests/test_profiler.py - benchmarks/benchmark_attention.py - benchmarks/benchmark_gemm.py Co-authored-by: Qwen-Coder <qwen-coder@alibabacloud.com>
1 parent 2837f06 commit cc1bb09

7 files changed

Lines changed: 444 additions & 134 deletions

File tree

benchmarks/benchmark_attention.py

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,7 @@ def benchmark_attention(
7676
print(f"\nBenchmarking seq_len={seq_len}...")
7777

7878
# Create inputs
79-
q = torch.randn(
80-
batch_size, num_heads, seq_len, head_dim, device="cuda", dtype=dtype
81-
)
79+
q = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda", dtype=dtype)
8280
k = torch.randn_like(q)
8381
v = torch.randn_like(q)
8482

@@ -107,9 +105,7 @@ def pytorch_attention(q, k, v):
107105
if has_custom:
108106
# Naive attention
109107
try:
110-
naive_time = benchmark_kernel(
111-
naive_attention, q, k, v, warmup, iterations
112-
)
108+
naive_time = benchmark_kernel(naive_attention, q, k, v, warmup, iterations)
113109
result["naive_ms"] = naive_time
114110
result["naive_tflops"] = (flops / 1e12) / (naive_time / 1000)
115111
result["naive_speedup"] = pytorch_time / naive_time
@@ -119,9 +115,7 @@ def pytorch_attention(q, k, v):
119115

120116
# Tiled attention
121117
try:
122-
tiled_time = benchmark_kernel(
123-
tiled_attention, q, k, v, warmup, iterations
124-
)
118+
tiled_time = benchmark_kernel(tiled_attention, q, k, v, warmup, iterations)
125119
result["tiled_ms"] = tiled_time
126120
result["tiled_tflops"] = (flops / 1e12) / (tiled_time / 1000)
127121
result["tiled_speedup"] = pytorch_time / tiled_time
@@ -131,9 +125,7 @@ def pytorch_attention(q, k, v):
131125

132126
# Flash attention
133127
try:
134-
flash_time = benchmark_kernel(
135-
flash_attention, q, k, v, warmup, iterations
136-
)
128+
flash_time = benchmark_kernel(flash_attention, q, k, v, warmup, iterations)
137129
result["flash_ms"] = flash_time
138130
result["flash_tflops"] = (flops / 1e12) / (flash_time / 1000)
139131
result["flash_speedup"] = pytorch_time / flash_time
@@ -143,9 +135,9 @@ def pytorch_attention(q, k, v):
143135

144136
# Record peak GPU memory
145137
result["peak_memory_mb"] = torch.cuda.max_memory_allocated() / (1024 * 1024)
146-
result["input_memory_mb"] = (
147-
mem_before + q.nelement() * q.element_size() * 3
148-
) / (1024 * 1024)
138+
result["input_memory_mb"] = (mem_before + q.nelement() * q.element_size() * 3) / (
139+
1024 * 1024
140+
)
149141

150142
results.append(result)
151143

@@ -162,9 +154,7 @@ def print_results(results: List[Dict]):
162154
print(
163155
f"\n{'Seq Len':>8} | {'PyTorch':>10} | {'Naive':>10} | {'Tiled':>10} | {'Flash':>10} | {'Best Speedup':>12}"
164156
)
165-
print(
166-
f"{'':>8} | {'(ms)':>10} | {'(ms)':>10} | {'(ms)':>10} | {'(ms)':>10} | {'':>12}"
167-
)
157+
print(f"{'':>8} | {'(ms)':>10} | {'(ms)':>10} | {'(ms)':>10} | {'(ms)':>10} | {'':>12}")
168158
print("-" * 80)
169159

170160
for r in results:
@@ -184,12 +174,8 @@ def print_results(results: List[Dict]):
184174
print("TFLOPS COMPARISON")
185175
print("=" * 80)
186176

187-
print(
188-
f"\n{'Seq Len':>8} | {'PyTorch':>12} | {'Naive':>12} | {'Tiled':>12} | {'Flash':>12}"
189-
)
190-
print(
191-
f"{'':>8} | {'(TFLOPS)':>12} | {'(TFLOPS)':>12} | {'(TFLOPS)':>12} | {'(TFLOPS)':>12}"
192-
)
177+
print(f"\n{'Seq Len':>8} | {'PyTorch':>12} | {'Naive':>12} | {'Tiled':>12} | {'Flash':>12}")
178+
print(f"{'':>8} | {'(TFLOPS)':>12} | {'(TFLOPS)':>12} | {'(TFLOPS)':>12} | {'(TFLOPS)':>12}")
193179
print("-" * 80)
194180

195181
for r in results:
@@ -213,17 +199,13 @@ def main():
213199
help="Sequence lengths to benchmark",
214200
)
215201
parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
216-
parser.add_argument(
217-
"--num-heads", type=int, default=32, help="Number of attention heads"
218-
)
202+
parser.add_argument("--num-heads", type=int, default=32, help="Number of attention heads")
219203
parser.add_argument("--head-dim", type=int, default=128, help="Head dimension")
220204
parser.add_argument(
221205
"--dtype", type=str, default="fp16", choices=["fp16", "fp32"], help="Data type"
222206
)
223207
parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations")
224-
parser.add_argument(
225-
"--iterations", type=int, default=100, help="Benchmark iterations"
226-
)
208+
parser.add_argument("--iterations", type=int, default=100, help="Benchmark iterations")
227209
parser.add_argument(
228210
"--output", type=str, default=None, help="Output JSON file path for results"
229211
)

benchmarks/benchmark_gemm.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,7 @@ def cublas_gemm(a, b):
103103
custom_time = benchmark_kernel(gemm, a, b, warmup, iterations)
104104
result["custom_ms"] = custom_time
105105
result["custom_tflops"] = (flops / 1e12) / (custom_time / 1000)
106-
result["custom_relative"] = (
107-
result["custom_tflops"] / result["cublas_tflops"]
108-
)
106+
result["custom_relative"] = result["custom_tflops"] / result["cublas_tflops"]
109107
except Exception as e:
110108
print(f" Custom GEMM failed: {e}")
111109
result["custom_ms"] = float("inf")
@@ -114,9 +112,7 @@ def cublas_gemm(a, b):
114112
# Tensor Core GEMM (FP16 only)
115113
if dtype == torch.float16:
116114
try:
117-
tc_time = benchmark_kernel(
118-
tensor_core_gemm, a, b, warmup, iterations
119-
)
115+
tc_time = benchmark_kernel(tensor_core_gemm, a, b, warmup, iterations)
120116
result["tensor_core_ms"] = tc_time
121117
result["tensor_core_tflops"] = (flops / 1e12) / (tc_time / 1000)
122118
result["tensor_core_relative"] = (
@@ -170,9 +166,7 @@ def print_results(results: List[Dict]):
170166
print("=" * 100)
171167

172168
print(f"\n{'Size':>20} | {'cuBLAS':>12} | {'Custom':>12} | {'TC GEMM':>12}")
173-
print(
174-
f"{'(M x N x K)':>20} | {'(TFLOPS)':>12} | {'(TFLOPS)':>12} | {'(TFLOPS)':>12}"
175-
)
169+
print(f"{'(M x N x K)':>20} | {'(TFLOPS)':>12} | {'(TFLOPS)':>12} | {'(TFLOPS)':>12}")
176170
print("-" * 100)
177171

178172
for r in results:
@@ -183,22 +177,16 @@ def print_results(results: List[Dict]):
183177

184178
tc_str = f"{tc_tflops:.2f}" if tc_tflops > 0 else "N/A"
185179

186-
print(
187-
f"{size_str:>20} | {cublas_tflops:>12.2f} | {custom_tflops:>12.2f} | {tc_str:>12}"
188-
)
180+
print(f"{size_str:>20} | {cublas_tflops:>12.2f} | {custom_tflops:>12.2f} | {tc_str:>12}")
189181

190182
# Summary
191183
print("\n" + "=" * 100)
192184
print("SUMMARY")
193185
print("=" * 100)
194186

195-
avg_custom_rel = (
196-
sum(r.get("custom_relative", 0) for r in results) / len(results) * 100
197-
)
187+
avg_custom_rel = sum(r.get("custom_relative", 0) for r in results) / len(results) * 100
198188
avg_tc_rel = sum(
199-
r.get("tensor_core_relative", 0)
200-
for r in results
201-
if r.get("tensor_core_relative", 0) > 0
189+
r.get("tensor_core_relative", 0) for r in results if r.get("tensor_core_relative", 0) > 0
202190
)
203191
tc_count = sum(1 for r in results if r.get("tensor_core_relative", 0) > 0)
204192
if tc_count > 0:
@@ -231,9 +219,7 @@ def main():
231219
"--dtype", type=str, default="fp16", choices=["fp16", "fp32"], help="Data type"
232220
)
233221
parser.add_argument("--warmup", type=int, default=10, help="Warmup iterations")
234-
parser.add_argument(
235-
"--iterations", type=int, default=100, help="Benchmark iterations"
236-
)
222+
parser.add_argument("--iterations", type=int, default=100, help="Benchmark iterations")
237223
parser.add_argument(
238224
"--output", type=str, default=None, help="Output JSON file path for results"
239225
)

python/bindings.cpp

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ namespace py = pybind11;
1010

1111
// Forward declarations
1212
void naive_attention_fp32(const float*, const float*, const float*, float*,
13-
int, int, int, int, float, cudaStream_t);
13+
int, int, int, int, float, bool, cudaStream_t);
1414
void naive_attention_fp16(const half*, const half*, const half*, half*,
15-
int, int, int, int, float, cudaStream_t);
15+
int, int, int, int, float, bool, cudaStream_t);
1616
void tiled_attention_fp32(const float*, const float*, const float*, float*,
17-
int, int, int, int, float, cudaStream_t);
17+
int, int, int, int, float, bool, cudaStream_t);
1818
void tiled_attention_fp16(const half*, const half*, const half*, half*,
19-
int, int, int, int, float, cudaStream_t);
19+
int, int, int, int, float, bool, cudaStream_t);
2020
void flash_attention_fp32(const float*, const float*, const float*, float*,
2121
int, int, int, int, float, bool, cudaStream_t);
2222
void flash_attention_fp16(const half*, const half*, const half*, half*,
@@ -96,30 +96,31 @@ torch::Tensor naive_attention(
9696
const torch::Tensor& q,
9797
const torch::Tensor& k,
9898
const torch::Tensor& v,
99-
float scale = 0.0f
99+
float scale = 0.0f,
100+
bool is_causal = false
100101
) {
101102
validate_attention_inputs(q, k, v);
102-
103+
103104
int batch_size = q.size(0);
104105
int num_heads = q.size(1);
105106
int seq_len = q.size(2);
106107
int head_dim = q.size(3);
107-
108+
108109
if (scale == 0.0f) {
109110
scale = 1.0f / sqrtf(static_cast<float>(head_dim));
110111
}
111-
112+
112113
auto output = torch::empty_like(q);
113114
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
114-
115+
115116
if (q.scalar_type() == torch::kFloat32) {
116117
naive_attention_fp32(
117118
q.data_ptr<float>(),
118119
k.data_ptr<float>(),
119120
v.data_ptr<float>(),
120121
output.data_ptr<float>(),
121122
batch_size, num_heads, seq_len, head_dim,
122-
scale, stream
123+
scale, is_causal, stream
123124
);
124125
} else {
125126
naive_attention_fp16(
@@ -128,10 +129,10 @@ torch::Tensor naive_attention(
128129
reinterpret_cast<const half*>(v.data_ptr<at::Half>()),
129130
reinterpret_cast<half*>(output.data_ptr<at::Half>()),
130131
batch_size, num_heads, seq_len, head_dim,
131-
scale, stream
132+
scale, is_causal, stream
132133
);
133134
}
134-
135+
135136
return output;
136137
}
137138

@@ -140,30 +141,31 @@ torch::Tensor tiled_attention(
140141
const torch::Tensor& q,
141142
const torch::Tensor& k,
142143
const torch::Tensor& v,
143-
float scale = 0.0f
144+
float scale = 0.0f,
145+
bool is_causal = false
144146
) {
145147
validate_attention_inputs(q, k, v);
146-
148+
147149
int batch_size = q.size(0);
148150
int num_heads = q.size(1);
149151
int seq_len = q.size(2);
150152
int head_dim = q.size(3);
151-
153+
152154
if (scale == 0.0f) {
153155
scale = 1.0f / sqrtf(static_cast<float>(head_dim));
154156
}
155-
157+
156158
auto output = torch::empty_like(q);
157159
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
158-
160+
159161
if (q.scalar_type() == torch::kFloat32) {
160162
tiled_attention_fp32(
161163
q.data_ptr<float>(),
162164
k.data_ptr<float>(),
163165
v.data_ptr<float>(),
164166
output.data_ptr<float>(),
165167
batch_size, num_heads, seq_len, head_dim,
166-
scale, stream
168+
scale, is_causal, stream
167169
);
168170
} else {
169171
tiled_attention_fp16(
@@ -172,10 +174,10 @@ torch::Tensor tiled_attention(
172174
reinterpret_cast<const half*>(v.data_ptr<at::Half>()),
173175
reinterpret_cast<half*>(output.data_ptr<at::Half>()),
174176
batch_size, num_heads, seq_len, head_dim,
175-
scale, stream
177+
scale, is_causal, stream
176178
);
177179
}
178-
180+
179181
return output;
180182
}
181183

@@ -331,33 +333,35 @@ torch::Tensor tensor_core_gemm_int8_wrapper(
331333

332334
PYBIND11_MODULE(cuda_llm_ops, m) {
333335
m.doc() = "CUDA LLM Kernel Optimization - High-performance attention and GEMM kernels";
334-
336+
335337
// Attention functions
336338
m.def("naive_attention", &naive_attention,
337-
py::arg("q"), py::arg("k"), py::arg("v"), py::arg("scale") = 0.0f,
339+
py::arg("q"), py::arg("k"), py::arg("v"),
340+
py::arg("scale") = 0.0f, py::arg("is_causal") = false,
338341
"Naive attention implementation (baseline)");
339-
342+
340343
m.def("tiled_attention", &tiled_attention,
341-
py::arg("q"), py::arg("k"), py::arg("v"), py::arg("scale") = 0.0f,
344+
py::arg("q"), py::arg("k"), py::arg("v"),
345+
py::arg("scale") = 0.0f, py::arg("is_causal") = false,
342346
"Tiled attention with shared memory optimization");
343-
347+
344348
m.def("flash_attention", &flash_attention,
345-
py::arg("q"), py::arg("k"), py::arg("v"),
349+
py::arg("q"), py::arg("k"), py::arg("v"),
346350
py::arg("scale") = 0.0f, py::arg("is_causal") = false,
347351
"FlashAttention with online softmax");
348-
352+
349353
// GEMM functions
350354
m.def("gemm", &gemm,
351355
py::arg("a"), py::arg("b"),
352356
py::arg("alpha") = 1.0f, py::arg("beta") = 0.0f,
353357
py::arg("trans_a") = false, py::arg("trans_b") = false,
354358
"High-performance GEMM with register tiling");
355-
359+
356360
m.def("tensor_core_gemm", &tensor_core_gemm,
357361
py::arg("a"), py::arg("b"),
358362
py::arg("alpha") = 1.0f, py::arg("beta") = 0.0f,
359363
"Tensor Core GEMM (FP16 input, FP32 output)");
360-
364+
361365
m.def("tensor_core_gemm_int8", &tensor_core_gemm_int8_wrapper,
362366
py::arg("a"), py::arg("b"),
363367
"Tensor Core GEMM (INT8 input, INT32 output, requires Turing+ SM>=7.2)");

python/profiler.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,12 @@ def profile_attention(
6969
) -> KernelMetrics:
7070
"""Profile attention kernel and compute metrics."""
7171
# Create inputs
72-
q = torch.randn(
73-
batch_size, num_heads, seq_len, head_dim, device="cuda", dtype=dtype
74-
)
72+
q = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda", dtype=dtype)
7573
k = torch.randn_like(q)
7674
v = torch.randn_like(q)
7775

7876
# Measure time
79-
elapsed_ms = self.measure_time(
80-
func, q, k, v, warmup=warmup, iterations=iterations
81-
)
77+
elapsed_ms = self.measure_time(func, q, k, v, warmup=warmup, iterations=iterations)
8278

8379
# Compute FLOPs
8480
# Attention: 2 * batch * heads * seq^2 * head_dim (Q@K^T)

tests/conftest.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,10 @@ def random_seed():
4545
def attention_inputs(device):
4646
"""Generate random attention inputs."""
4747

48-
def _generate(
49-
batch_size=2, num_heads=4, seq_len=64, head_dim=32, dtype=torch.float32
50-
):
51-
q = torch.randn(
52-
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
53-
)
54-
k = torch.randn(
55-
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
56-
)
57-
v = torch.randn(
58-
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
59-
)
48+
def _generate(batch_size=2, num_heads=4, seq_len=64, head_dim=32, dtype=torch.float32):
49+
q = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
50+
k = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
51+
v = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
6052
return q, k, v
6153

6254
return _generate
@@ -81,8 +73,7 @@ def assert_close(actual, expected, rtol=1e-3, atol=1e-3, msg=""):
8173
max_diff = diff.max().item()
8274
mean_diff = diff.mean().item()
8375
raise AssertionError(
84-
f"{msg}\nMax diff: {max_diff}, Mean diff: {mean_diff}, "
85-
f"rtol: {rtol}, atol: {atol}"
76+
f"{msg}\nMax diff: {max_diff}, Mean diff: {mean_diff}, rtol: {rtol}, atol: {atol}"
8677
)
8778

8879

0 commit comments

Comments (0)