Skip to content

Commit d833c84

Browse files
shijiashuaiCopilot
andcommitted
fix(ci): format python examples
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 7605688 commit d833c84

2 files changed

Lines changed: 81 additions & 49 deletions

File tree

examples/python/basic_usage.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
"Build the bindings first:\n"
1616
" cmake -S . -B build -DBUILD_PYTHON_BINDINGS=ON\n"
1717
" cmake --build build\n"
18-
" export PYTHONPATH=\"$(pwd)/build/python:${PYTHONPATH}\"\n"
18+
' export PYTHONPATH="$(pwd)/build/python:${PYTHONPATH}"\n'
1919
) from exc
2020

2121

2222
def require_cuda() -> torch.device:
2323
if not torch.cuda.is_available():
24-
raise SystemExit("Error: this example requires a CUDA-enabled PyTorch installation.")
24+
raise SystemExit(
25+
"Error: this example requires a CUDA-enabled PyTorch installation."
26+
)
2527
return torch.device("cuda")
2628

2729

@@ -50,7 +52,9 @@ def example_reduction(device: torch.device) -> None:
5052
x = torch.randn(64, 128, device=device, dtype=torch.float32)
5153
softmax_out = torch.empty_like(x)
5254
opt.reduction.softmax(x, softmax_out, x.shape[0], x.shape[1])
53-
torch.testing.assert_close(softmax_out, torch.softmax(x, dim=-1), rtol=1e-5, atol=1e-5)
55+
torch.testing.assert_close(
56+
softmax_out, torch.softmax(x, dim=-1), rtol=1e-5, atol=1e-5
57+
)
5458
print("Softmax passed")
5559

5660

python/benchmark/benchmark.py

Lines changed: 74 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@
2121
# Optional imports for visualization
2222
try:
2323
import matplotlib.pyplot as plt
24+
2425
HAS_MATPLOTLIB = True
2526
except ImportError:
2627
HAS_MATPLOTLIB = False
2728

2829
try:
2930
import numpy as np
31+
3032
HAS_NUMPY = True
3133
except ImportError:
3234
HAS_NUMPY = False
@@ -35,6 +37,7 @@
3537
@dataclass
3638
class BenchmarkResult:
3739
"""Container for benchmark results."""
40+
3841
kernel: str
3942
hpc_ms: float
4043
baseline_ms: float
@@ -48,6 +51,7 @@ class BenchmarkResult:
4851
@dataclass
4952
class DeviceInfo:
5053
"""GPU device information."""
54+
5155
name: str
5256
compute_capability: Tuple[int, int]
5357
total_memory_gb: float
@@ -99,7 +103,7 @@ def benchmark_kernel(
99103
min_run_time: float = 1.0,
100104
flops: Optional[int] = None,
101105
bytes_accessed: Optional[int] = None,
102-
**kwargs
106+
**kwargs,
103107
) -> BenchmarkResult:
104108
"""
105109
Compare HPC kernel with baseline implementation.
@@ -128,14 +132,14 @@ def benchmark_kernel(
128132
# Benchmark HPC kernel
129133
hpc_timer = Timer(
130134
stmt="hpc_fn(*args, **kwargs)",
131-
globals={"hpc_fn": hpc_fn, "args": args, "kwargs": kwargs}
135+
globals={"hpc_fn": hpc_fn, "args": args, "kwargs": kwargs},
132136
)
133137
hpc_result = hpc_timer.blocked_autorange(min_run_time=min_run_time)
134138

135139
# Benchmark baseline
136140
baseline_timer = Timer(
137141
stmt="baseline_fn(*args, **kwargs)",
138-
globals={"baseline_fn": baseline_fn, "args": args, "kwargs": kwargs}
142+
globals={"baseline_fn": baseline_fn, "args": args, "kwargs": kwargs},
139143
)
140144
baseline_result = baseline_timer.blocked_autorange(min_run_time=min_run_time)
141145

@@ -201,7 +205,9 @@ def analyze(self, result: BenchmarkResult) -> Dict[str, Any]:
201205
achieved_tflops = result.tflops
202206

203207
# Ridge point: where compute and memory rooflines meet
204-
ridge_point = self.device_info.peak_fp32_tflops / self.device_info.peak_bandwidth_gb_s
208+
ridge_point = (
209+
self.device_info.peak_fp32_tflops / self.device_info.peak_bandwidth_gb_s
210+
)
205211

206212
# Determine bottleneck
207213
if ai < ridge_point:
@@ -228,7 +234,7 @@ def plot_roofline(
228234
self,
229235
results: List[BenchmarkResult],
230236
output_path: str = "roofline.png",
231-
title: str = "Roofline Analysis"
237+
title: str = "Roofline Analysis",
232238
):
233239
"""Generate roofline plot for multiple kernels."""
234240
if not HAS_MATPLOTLIB or not HAS_NUMPY:
@@ -252,9 +258,11 @@ def plot_roofline(
252258
roofline = np.minimum(memory_roof, compute_roof)
253259

254260
# Plot roofline
255-
ax.loglog(ai_range, roofline, 'b-', linewidth=2, label='Roofline')
256-
ax.loglog(ai_range, memory_roof, 'b--', alpha=0.5, label='Memory Bound')
257-
ax.axhline(y=peak_compute, color='b', linestyle=':', alpha=0.5, label='Compute Bound')
261+
ax.loglog(ai_range, roofline, "b-", linewidth=2, label="Roofline")
262+
ax.loglog(ai_range, memory_roof, "b--", alpha=0.5, label="Memory Bound")
263+
ax.axhline(
264+
y=peak_compute, color="b", linestyle=":", alpha=0.5, label="Compute Bound"
265+
)
258266

259267
# Plot kernel results
260268
colors = plt.cm.tab10(np.linspace(0, 1, len(results)))
@@ -265,25 +273,25 @@ def plot_roofline(
265273
result.tflops,
266274
s=200,
267275
c=[color],
268-
marker='o',
276+
marker="o",
269277
label=result.kernel,
270-
zorder=5
278+
zorder=5,
271279
)
272280

273281
# Ridge point
274282
ridge_point = peak_compute / peak_bandwidth
275-
ax.axvline(x=ridge_point, color='gray', linestyle='--', alpha=0.5)
283+
ax.axvline(x=ridge_point, color="gray", linestyle="--", alpha=0.5)
276284
ax.annotate(
277-
f'Ridge Point\n({ridge_point:.1f} FLOP/B)',
285+
f"Ridge Point\n({ridge_point:.1f} FLOP/B)",
278286
xy=(ridge_point, peak_compute * 0.5),
279287
fontsize=9,
280-
ha='center'
288+
ha="center",
281289
)
282290

283-
ax.set_xlabel('Arithmetic Intensity (FLOP/Byte)', fontsize=12)
284-
ax.set_ylabel('Performance (TFLOPS)', fontsize=12)
285-
ax.set_title(f'{title}\n{self.device_info.name}', fontsize=14)
286-
ax.legend(loc='lower right')
291+
ax.set_xlabel("Arithmetic Intensity (FLOP/Byte)", fontsize=12)
292+
ax.set_ylabel("Performance (TFLOPS)", fontsize=12)
293+
ax.set_title(f"{title}\n{self.device_info.name}", fontsize=14)
294+
ax.legend(loc="lower right")
287295
ax.grid(True, alpha=0.3)
288296
ax.set_xlim(0.01, 10000)
289297
ax.set_ylim(0.01, peak_compute * 2)
@@ -294,13 +302,17 @@ def plot_roofline(
294302
print(f"Roofline plot saved to {output_path}")
295303

296304

297-
def print_results(results: List[BenchmarkResult], device_info: Optional[DeviceInfo] = None):
305+
def print_results(
306+
results: List[BenchmarkResult], device_info: Optional[DeviceInfo] = None
307+
):
298308
"""Print benchmark results in a formatted table."""
299309
print("\n" + "=" * 90)
300310
if device_info:
301311
print(f"Device: {device_info.name}")
302-
print(f"Peak FP32: {device_info.peak_fp32_tflops:.1f} TFLOPS | "
303-
f"Peak Bandwidth: {device_info.peak_bandwidth_gb_s:.0f} GB/s")
312+
print(
313+
f"Peak FP32: {device_info.peak_fp32_tflops:.1f} TFLOPS | "
314+
f"Peak Bandwidth: {device_info.peak_bandwidth_gb_s:.0f} GB/s"
315+
)
304316
print("=" * 90)
305317

306318
header = f"{'Kernel':<25} {'HPC (ms)':<10} {'Base (ms)':<10} {'Speedup':<10}"
@@ -325,7 +337,7 @@ def generate_html_report(
325337
results: List[BenchmarkResult],
326338
device_info: DeviceInfo,
327339
output_path: str = "benchmark_report.html",
328-
roofline_image: Optional[str] = None
340+
roofline_image: Optional[str] = None,
329341
):
330342
"""Generate HTML benchmark report."""
331343
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -415,15 +427,15 @@ def generate_html_report(
415427
</html>
416428
"""
417429

418-
with open(output_path, 'w') as f:
430+
with open(output_path, "w") as f:
419431
f.write(html)
420432
print(f"HTML report saved to {output_path}")
421433

422434

423435
def plot_speedup_chart(
424436
results: List[BenchmarkResult],
425437
output_path: str = "speedup_chart.png",
426-
title: str = "Kernel Speedup vs Baseline"
438+
title: str = "Kernel Speedup vs Baseline",
427439
):
428440
"""Generate speedup bar chart."""
429441
if not HAS_MATPLOTLIB:
@@ -434,30 +446,30 @@ def plot_speedup_chart(
434446

435447
kernels = [r.kernel for r in results]
436448
speedups = [r.speedup for r in results]
437-
colors = ['#4CAF50' if s >= 1.0 else '#f44336' for s in speedups]
449+
colors = ["#4CAF50" if s >= 1.0 else "#f44336" for s in speedups]
438450

439451
bars = ax.bar(kernels, speedups, color=colors)
440452

441453
# Add value labels
442454
for bar, speedup in zip(bars, speedups):
443455
height = bar.get_height()
444456
ax.annotate(
445-
f'{speedup:.2f}x',
457+
f"{speedup:.2f}x",
446458
xy=(bar.get_x() + bar.get_width() / 2, height),
447459
xytext=(0, 3),
448460
textcoords="offset points",
449-
ha='center',
450-
va='bottom',
451-
fontsize=10
461+
ha="center",
462+
va="bottom",
463+
fontsize=10,
452464
)
453465

454-
ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.7, label='Baseline')
455-
ax.set_xlabel('Kernel', fontsize=12)
456-
ax.set_ylabel('Speedup', fontsize=12)
466+
ax.axhline(y=1.0, color="gray", linestyle="--", alpha=0.7, label="Baseline")
467+
ax.set_xlabel("Kernel", fontsize=12)
468+
ax.set_ylabel("Speedup", fontsize=12)
457469
ax.set_title(title, fontsize=14)
458470
ax.legend()
459471

460-
plt.xticks(rotation=45, ha='right')
472+
plt.xticks(rotation=45, ha="right")
461473
plt.tight_layout()
462474
plt.savefig(output_path, dpi=150)
463475
plt.close()
@@ -474,16 +486,26 @@ def main():
474486
python benchmark.py --suite gemm --sizes 1024,2048,4096
475487
python benchmark.py --suite all --output results.json --html report.html
476488
python benchmark.py --roofline --output roofline.png
477-
"""
489+
""",
490+
)
491+
parser.add_argument(
492+
"--suite",
493+
type=str,
494+
default="all",
495+
choices=["all", "gemm", "elementwise", "reduction", "attention"],
496+
help="Benchmark suite to run",
497+
)
498+
parser.add_argument(
499+
"--sizes",
500+
type=str,
501+
default="1024,2048,4096",
502+
help="Comma-separated list of sizes to benchmark",
478503
)
479-
parser.add_argument("--suite", type=str, default="all",
480-
choices=["all", "gemm", "elementwise", "reduction", "attention"],
481-
help="Benchmark suite to run")
482-
parser.add_argument("--sizes", type=str, default="1024,2048,4096",
483-
help="Comma-separated list of sizes to benchmark")
484504
parser.add_argument("--output", type=str, help="Output JSON file for results")
485505
parser.add_argument("--html", type=str, help="Output HTML report file")
486-
parser.add_argument("--roofline", action="store_true", help="Generate roofline plot")
506+
parser.add_argument(
507+
"--roofline", action="store_true", help="Generate roofline plot"
508+
)
487509
parser.add_argument("--chart", action="store_true", help="Generate speedup chart")
488510
args = parser.parse_args()
489511

@@ -494,7 +516,9 @@ def main():
494516
# Get device info
495517
device_info = get_device_info()
496518
print(f"\nDevice: {device_info.name}")
497-
print(f"Compute Capability: {device_info.compute_capability[0]}.{device_info.compute_capability[1]}")
519+
print(
520+
f"Compute Capability: {device_info.compute_capability[0]}.{device_info.compute_capability[1]}"
521+
)
498522
print(f"Peak FP32: {device_info.peak_fp32_tflops:.1f} TFLOPS")
499523
print(f"Peak Bandwidth: {device_info.peak_bandwidth_gb_s:.0f} GB/s")
500524

@@ -515,12 +539,16 @@ def main():
515539
print_results(results, device_info)
516540

517541
if args.output:
518-
with open(args.output, 'w') as f:
519-
json.dump({
520-
"device": asdict(device_info),
521-
"results": [asdict(r) for r in results],
522-
"timestamp": datetime.now().isoformat()
523-
}, f, indent=2)
542+
with open(args.output, "w") as f:
543+
json.dump(
544+
{
545+
"device": asdict(device_info),
546+
"results": [asdict(r) for r in results],
547+
"timestamp": datetime.now().isoformat(),
548+
},
549+
f,
550+
indent=2,
551+
)
524552
print(f"Results saved to {args.output}")
525553

526554
if args.html:

0 commit comments

Comments
 (0)