Skip to content

Commit dc6ff64

Browse files
authored
Merge branch 'main' into cleanup-huggingface-hub-integration
2 parents 1fe909f + 6ce5b33 commit dc6ff64

2 files changed

Lines changed: 81 additions & 9 deletions

File tree

kernels/src/kernels/benchmark.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import subprocess
88
import sys
99
import time
10+
import warnings
1011
from dataclasses import dataclass
1112
from pathlib import Path
1213
from typing import Any
@@ -471,6 +472,7 @@ def run_benchmark_class(
471472
iterations: int,
472473
warmup: int,
473474
repo_id: str,
475+
is_local: bool,
474476
revision: str,
475477
) -> tuple[dict[str, TimingResults], str]:
476478
results = {}
@@ -486,9 +488,13 @@ def run_benchmark_class(
486488
raise RuntimeError(f"No benchmark_* methods found in {benchmark_cls.__name__}")
487489

488490
# Load kernel once for all workloads
489-
from kernels import get_kernel
491+
from kernels import get_local_kernel, get_kernel
492+
493+
if is_local:
494+
kernel = get_local_kernel(Path(repo_id), "activation")
495+
else:
496+
kernel = get_kernel(repo_id, revision=revision)
490497

491-
kernel = get_kernel(repo_id, revision=revision)
492498
kernel_sha = get_kernel_sha_from_build_name(kernel)
493499
backend_name = backend() if TORCH_AVAILABLE else "cpu"
494500
# Map backend names to torch device names
@@ -647,6 +653,7 @@ def run_benchmark_script(
647653
warmup: int,
648654
cwd: Path,
649655
repo_id: str,
656+
is_local: bool,
650657
revision: str,
651658
) -> tuple[dict[str, TimingResults], str]:
652659
print(f"Running {script_path.name}...", file=sys.stderr)
@@ -674,6 +681,7 @@ def run_benchmark_script(
674681
iterations=iterations,
675682
warmup=warmup,
676683
repo_id=repo_id,
684+
is_local=is_local,
677685
revision=revision,
678686
)
679687
for name, timing in results.items():
@@ -717,6 +725,24 @@ def run_benchmark(
717725
# Suppress progress bars for cleaner output (files are often cached)
718726
disable_progress_bars()
719727

728+
repo_id_path = Path(repo_id)
729+
730+
if repo_id_path.is_absolute():
731+
is_local = repo_id_path.exists()
732+
else:
733+
is_local = (Path.cwd() / repo_id_path).exists()
734+
repo_id_path = Path.cwd() / repo_id_path
735+
736+
if is_local:
737+
if repo_id.count("/") == 1 and not repo_id.startswith(("./", "../")):
738+
warnings.warn(
739+
f"'{repo_id}' exists locally but looks like a repo_id. "
740+
f"Use './{repo_id}' to be explicit.",
741+
stacklevel=2,
742+
)
743+
branch = "local"
744+
version = None
745+
720746
# Requires either branch or version or parses from repo_id
721747
if branch is None and version is None:
722748
if "@" not in repo_id:
@@ -739,7 +765,11 @@ def run_benchmark(
739765
assert revision is not None # Guaranteed by parsing logic above
740766

741767
print(f"Downloading {repo_id}@{revision}...", file=sys.stderr)
742-
repo_path = Path(str(_get_hf_api().snapshot_download(repo_id=repo_id, revision=revision)))
768+
769+
if is_local:
770+
repo_path = repo_id_path.resolve()
771+
else:
772+
repo_path = Path(str(_get_hf_api().snapshot_download(repo_id=repo_id, revision=revision)))
743773

744774
scripts = discover_benchmark_scripts(repo_id, repo_path)
745775

@@ -753,6 +783,7 @@ def run_benchmark(
753783
warmup=warmup,
754784
cwd=repo_path,
755785
repo_id=repo_id,
786+
is_local=is_local,
756787
revision=revision,
757788
)
758789
timing_results.update(results)

kernels/src/kernels/benchmarks/activation.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ class SiluAndMulBenchmark(Benchmark):
99

1010
# Workload: small
1111
def setup_small(self):
12-
self.x = torch.randn(1, 128, 512, device="cuda", dtype=torch.float16)
13-
self.out = torch.empty(1, 128, 256, device="cuda", dtype=torch.float16)
12+
self.x = torch.randn(8, 1024, 2048, device=self.device, dtype=torch.float16)
13+
self.out = torch.empty(8, 1024, 1024, device=self.device, dtype=torch.float16)
1414

1515
def benchmark_small(self):
1616
self.kernel.silu_and_mul(self.out, self.x)
@@ -21,8 +21,8 @@ def verify_small(self) -> torch.Tensor:
2121

2222
# Workload: medium
2323
def setup_medium(self):
24-
self.x = torch.randn(4, 512, 1024, device="cuda", dtype=torch.float16)
25-
self.out = torch.empty(4, 512, 512, device="cuda", dtype=torch.float16)
24+
self.x = torch.randn(8, 2048, 4096, device=self.device, dtype=torch.float16)
25+
self.out = torch.empty(8, 2048, 2048, device=self.device, dtype=torch.float16)
2626

2727
def benchmark_medium(self):
2828
self.kernel.silu_and_mul(self.out, self.x)
@@ -33,12 +33,53 @@ def verify_medium(self) -> torch.Tensor:
3333

3434
# Workload: large
3535
def setup_large(self):
36-
self.x = torch.randn(8, 1024, 2048, device="cuda", dtype=torch.float16)
37-
self.out = torch.empty(8, 1024, 1024, device="cuda", dtype=torch.float16)
36+
self.x = torch.randn(8, 4096, 8192, device=self.device, dtype=torch.float16)
37+
self.out = torch.empty(8, 4096, 4096, device=self.device, dtype=torch.float16)
3838

3939
def benchmark_large(self):
    """Run the kernel under test once on the large workload.

    Calls ``self.kernel.silu_and_mul(self.out, self.x)`` exactly one time, so
    each call to this method performs a single kernel invocation, the same as
    ``benchmark_small`` and ``benchmark_medium``. A second back-to-back call
    would make this workload do twice the work per invocation of the method.
    """
    self.kernel.silu_and_mul(self.out, self.x)
4142

4243
def verify_large(self) -> torch.Tensor:
4344
d = self.x.shape[-1] // 2
4445
return F.silu(self.x[..., :d]) * self.x[..., d:]
46+
47+
48+
class GeluAndMulBenchmark(Benchmark):
    """Benchmark of the kernel's ``gelu_and_mul`` op at three workload sizes.

    Every workload follows the same three-method pattern:

    * ``setup_<name>`` allocates a random float16 input ``self.x`` and an
      uninitialised float16 output ``self.out`` whose last dimension is half
      of the input's last dimension.
    * ``benchmark_<name>`` calls ``self.kernel.gelu_and_mul(self.out, self.x)``.
    * ``verify_<name>`` returns a PyTorch reference result: GELU of the first
      half of the input's last dimension multiplied by the second half.

    NOTE(review): ``self.device`` and ``self.kernel`` are not set in this class;
    presumably the ``Benchmark`` base class / harness provides them - confirm.
    """

    seed: int = 42

    # Workload: small
    def setup_small(self):
        """Allocate the (8, 1024, 2048) input and (8, 1024, 1024) output."""
        batch, rows, width = 8, 1024, 2048
        self.x = torch.randn(
            batch, rows, width, device=self.device, dtype=torch.float16
        )
        self.out = torch.empty(
            batch, rows, width // 2, device=self.device, dtype=torch.float16
        )

    def benchmark_small(self):
        """Invoke ``gelu_and_mul`` once on the small workload."""
        self.kernel.gelu_and_mul(self.out, self.x)

    def verify_small(self) -> torch.Tensor:
        """Return the PyTorch reference output for the small workload."""
        half = self.x.shape[-1] // 2
        first_half = self.x[..., :half]
        second_half = self.x[..., half:]
        return F.gelu(first_half) * second_half

    # Workload: medium
    def setup_medium(self):
        """Allocate the (8, 2048, 4096) input and (8, 2048, 2048) output."""
        batch, rows, width = 8, 2048, 4096
        self.x = torch.randn(
            batch, rows, width, device=self.device, dtype=torch.float16
        )
        self.out = torch.empty(
            batch, rows, width // 2, device=self.device, dtype=torch.float16
        )

    def benchmark_medium(self):
        """Invoke ``gelu_and_mul`` once on the medium workload."""
        self.kernel.gelu_and_mul(self.out, self.x)

    def verify_medium(self) -> torch.Tensor:
        """Return the PyTorch reference output for the medium workload."""
        half = self.x.shape[-1] // 2
        first_half = self.x[..., :half]
        second_half = self.x[..., half:]
        return F.gelu(first_half) * second_half

    # Workload: large
    def setup_large(self):
        """Allocate the (8, 4096, 8192) input and (8, 4096, 4096) output."""
        batch, rows, width = 8, 4096, 8192
        self.x = torch.randn(
            batch, rows, width, device=self.device, dtype=torch.float16
        )
        self.out = torch.empty(
            batch, rows, width // 2, device=self.device, dtype=torch.float16
        )

    def benchmark_large(self):
        """Invoke ``gelu_and_mul`` once on the large workload."""
        self.kernel.gelu_and_mul(self.out, self.x)

    def verify_large(self) -> torch.Tensor:
        """Return the PyTorch reference output for the large workload."""
        half = self.x.shape[-1] // 2
        first_half = self.x[..., :half]
        second_half = self.x[..., half:]
        return F.gelu(first_half) * second_half

0 commit comments

Comments
 (0)