Skip to content

Commit 6576834

Browse files
danieldkdrbh
andauthored
Benchmark command fixes (#229)
* feat: upload benchmarks folder and subcommand cleanups * feat: respect versions on upload and download * fix: adjust for comat and mypy * Simplify `kernels upload` changes to avoid blocking release --------- Co-authored-by: David Holtz <david.richard.holtz@gmail.com>
1 parent f243b69 commit 6576834

7 files changed

Lines changed: 268 additions & 17 deletions

File tree

kernels/src/kernels/benchmark.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,11 @@ def verify_silu(self) -> torch.Tensor:
8888

8989
seed: int | None = None # Optional: seed for reproducibility
9090

91-
def __init__(self):
92-
self.kernel = None
91+
def __init__(self) -> None:
92+
self.kernel: Any = None
9393
self.out: Any = None # Output tensor, set by setup methods
9494

95-
def setup(self):
95+
def setup(self) -> None:
9696
"""Override to set up tensors as instance attributes."""
9797
pass
9898

@@ -452,16 +452,16 @@ def collect_machine_info() -> MachineInfo:
452452
)
453453

454454

455-
def get_kernel_sha_from_ops(kernel: Any) -> str:
456-
ops_name = kernel.ops.__name__
457-
# Format is torch.ops._<name>_<sha>, extract the last part after underscore
455+
def get_kernel_sha_from_build_name(kernel: Any) -> str:
456+
ops_name = kernel.__name__
457+
# Format is <name>_<sha>, extract the last part after underscore
458458
sha = ops_name.rsplit("_", 1)[-1]
459459
return sha
460460

461461

462462
def _synchronize() -> None:
463463
if torch.cuda.is_available():
464-
_synchronize()
464+
torch.cuda.synchronize()
465465
elif hasattr(torch, "xpu") and torch.xpu.is_available():
466466
torch.xpu.synchronize()
467467
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
@@ -491,7 +491,7 @@ def run_benchmark_class(
491491
from kernels import get_kernel
492492

493493
kernel = get_kernel(repo_id, revision=revision)
494-
kernel_sha = get_kernel_sha_from_ops(kernel)
494+
kernel_sha = get_kernel_sha_from_build_name(kernel)
495495

496496
for method_name in benchmark_methods:
497497
workload_name = method_name.replace("benchmark_", "")
@@ -689,8 +689,8 @@ def submit_benchmark(
689689

690690
def run_benchmark(
691691
repo_id: str,
692-
# TODO: change default to 1 in the future
693-
revision: str = "main",
692+
branch: str | None,
693+
version: int | None,
694694
iterations: int = 100,
695695
warmup: int = 10,
696696
upload: bool = False,
@@ -708,6 +708,27 @@ def run_benchmark(
708708
# Suppress progress bars for cleaner output (files are often cached)
709709
disable_progress_bars()
710710

711+
# Requires either branch or version or parses from repo_id
712+
if branch is None and version is None:
713+
if "@" not in repo_id:
714+
print("Error: must specify either branch or version", file=sys.stderr)
715+
sys.exit(1)
716+
717+
# Parse from repo_id
718+
repo_id, rev = repo_id.split("@", 1)
719+
720+
if rev.startswith("v") and rev[1:].isdigit():
721+
version = int(rev[1:])
722+
elif rev.isdigit():
723+
print("Error: version must be prefixed with 'v'", file=sys.stderr)
724+
sys.exit(1)
725+
else:
726+
branch = rev
727+
728+
# Move version or branch into revision
729+
revision = f"v{version}" if version is not None else branch
730+
assert revision is not None # Guaranteed by parsing logic above
731+
711732
print(f"Downloading {repo_id}@{revision}...", file=sys.stderr)
712733
repo_path = Path(snapshot_download(repo_id=repo_id, revision=revision))
713734

kernels/src/kernels/benchmarks/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
FlashAttentionCausalBenchmark,
55
FlashAttentionVarlenBenchmark,
66
)
7+
from .layer_norm import LayerNormBenchmark, RMSNormBenchmark
78

89
__all__ = [
910
"FlashAttentionBenchmark",
1011
"FlashAttentionCausalBenchmark",
1112
"FlashAttentionVarlenBenchmark",
13+
"LayerNormBenchmark",
14+
"RMSNormBenchmark",
1215
"SiluAndMulBenchmark",
1316
]
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import torch
2+
3+
from kernels.benchmark import Benchmark
4+
5+
6+
class RMSNormBenchmark(Benchmark):
7+
seed: int = 42
8+
eps: float = 1e-5
9+
10+
# Workload: small (B=2, S=128, D=768)
11+
def setup_small(self):
12+
B, S, D = 2, 128, 768
13+
self.x = torch.randn(B, S, D, device="cuda", dtype=torch.float16)
14+
self.weight = torch.ones(D, device="cuda", dtype=torch.float16)
15+
self.out = torch.empty_like(self.x)
16+
self.B, self.S, self.D = B, S, D
17+
18+
def benchmark_small(self):
19+
self.out = self.kernel.dropout_add_ln_fwd(
20+
input=self.x.view(-1, self.D),
21+
gamma=self.weight,
22+
beta=None,
23+
rowscale=None,
24+
colscale=None,
25+
x0_subset=None,
26+
z_subset=None,
27+
dropout_p=0.0,
28+
epsilon=self.eps,
29+
rowscale_const=1.0,
30+
z_numrows=self.S,
31+
gen=None,
32+
residual_in_fp32=False,
33+
is_rms_norm=True,
34+
)[0].view(self.B, self.S, self.D)
35+
36+
def verify_small(self) -> torch.Tensor:
37+
var = self.x.pow(2).mean(-1, keepdim=True)
38+
return (self.x * torch.rsqrt(var + self.eps)) * self.weight
39+
40+
# Workload: medium (B=4, S=512, D=2048)
41+
def setup_medium(self):
42+
B, S, D = 4, 512, 2048
43+
self.x = torch.randn(B, S, D, device="cuda", dtype=torch.float16)
44+
self.weight = torch.ones(D, device="cuda", dtype=torch.float16)
45+
self.out = torch.empty_like(self.x)
46+
self.B, self.S, self.D = B, S, D
47+
48+
def benchmark_medium(self):
49+
self.out = self.kernel.dropout_add_ln_fwd(
50+
input=self.x.view(-1, self.D),
51+
gamma=self.weight,
52+
beta=None,
53+
rowscale=None,
54+
colscale=None,
55+
x0_subset=None,
56+
z_subset=None,
57+
dropout_p=0.0,
58+
epsilon=self.eps,
59+
rowscale_const=1.0,
60+
z_numrows=self.S,
61+
gen=None,
62+
residual_in_fp32=False,
63+
is_rms_norm=True,
64+
)[0].view(self.B, self.S, self.D)
65+
66+
def verify_medium(self) -> torch.Tensor:
67+
var = self.x.pow(2).mean(-1, keepdim=True)
68+
return (self.x * torch.rsqrt(var + self.eps)) * self.weight
69+
70+
# Workload: large (B=8, S=1024, D=4096)
71+
def setup_large(self):
72+
B, S, D = 8, 1024, 4096
73+
self.x = torch.randn(B, S, D, device="cuda", dtype=torch.float16)
74+
self.weight = torch.ones(D, device="cuda", dtype=torch.float16)
75+
self.out = torch.empty_like(self.x)
76+
self.B, self.S, self.D = B, S, D
77+
78+
def benchmark_large(self):
79+
self.out = self.kernel.dropout_add_ln_fwd(
80+
input=self.x.view(-1, self.D),
81+
gamma=self.weight,
82+
beta=None,
83+
rowscale=None,
84+
colscale=None,
85+
x0_subset=None,
86+
z_subset=None,
87+
dropout_p=0.0,
88+
epsilon=self.eps,
89+
rowscale_const=1.0,
90+
z_numrows=self.S,
91+
gen=None,
92+
residual_in_fp32=False,
93+
is_rms_norm=True,
94+
)[0].view(self.B, self.S, self.D)
95+
96+
def verify_large(self) -> torch.Tensor:
97+
var = self.x.pow(2).mean(-1, keepdim=True)
98+
return (self.x * torch.rsqrt(var + self.eps)) * self.weight
99+
100+
101+
class LayerNormBenchmark(Benchmark):
102+
seed: int = 42
103+
eps: float = 1e-5
104+
105+
# Workload: small (B=2, S=128, D=768)
106+
def setup_small(self):
107+
B, S, D = 2, 128, 768
108+
self.x = torch.randn(B, S, D, device="cuda", dtype=torch.float16)
109+
self.weight = torch.ones(D, device="cuda", dtype=torch.float16)
110+
self.out = torch.empty_like(self.x)
111+
self.B, self.S, self.D = B, S, D
112+
113+
def benchmark_small(self):
114+
self.out = self.kernel.dropout_add_ln_fwd(
115+
input=self.x.view(-1, self.D),
116+
gamma=self.weight,
117+
beta=None,
118+
rowscale=None,
119+
colscale=None,
120+
x0_subset=None,
121+
z_subset=None,
122+
dropout_p=0.0,
123+
epsilon=self.eps,
124+
rowscale_const=1.0,
125+
z_numrows=self.S,
126+
gen=None,
127+
residual_in_fp32=False,
128+
is_rms_norm=False,
129+
)[0].view(self.B, self.S, self.D)
130+
131+
def verify_small(self) -> torch.Tensor:
132+
return torch.nn.functional.layer_norm(
133+
self.x, [self.D], self.weight, eps=self.eps
134+
)
135+
136+
# Workload: medium (B=4, S=512, D=2048)
137+
def setup_medium(self):
138+
B, S, D = 4, 512, 2048
139+
self.x = torch.randn(B, S, D, device="cuda", dtype=torch.float16)
140+
self.weight = torch.ones(D, device="cuda", dtype=torch.float16)
141+
self.out = torch.empty_like(self.x)
142+
self.B, self.S, self.D = B, S, D
143+
144+
def benchmark_medium(self):
145+
self.out = self.kernel.dropout_add_ln_fwd(
146+
input=self.x.view(-1, self.D),
147+
gamma=self.weight,
148+
beta=None,
149+
rowscale=None,
150+
colscale=None,
151+
x0_subset=None,
152+
z_subset=None,
153+
dropout_p=0.0,
154+
epsilon=self.eps,
155+
rowscale_const=1.0,
156+
z_numrows=self.S,
157+
gen=None,
158+
residual_in_fp32=False,
159+
is_rms_norm=False,
160+
)[0].view(self.B, self.S, self.D)
161+
162+
def verify_medium(self) -> torch.Tensor:
163+
return torch.nn.functional.layer_norm(
164+
self.x, [self.D], self.weight, eps=self.eps
165+
)
166+
167+
# Workload: large (B=8, S=1024, D=4096)
168+
def setup_large(self):
169+
B, S, D = 8, 1024, 4096
170+
self.x = torch.randn(B, S, D, device="cuda", dtype=torch.float16)
171+
self.weight = torch.ones(D, device="cuda", dtype=torch.float16)
172+
self.out = torch.empty_like(self.x)
173+
self.B, self.S, self.D = B, S, D
174+
175+
def benchmark_large(self):
176+
self.out = self.kernel.dropout_add_ln_fwd(
177+
input=self.x.view(-1, self.D),
178+
gamma=self.weight,
179+
beta=None,
180+
rowscale=None,
181+
colscale=None,
182+
x0_subset=None,
183+
z_subset=None,
184+
dropout_p=0.0,
185+
epsilon=self.eps,
186+
rowscale_const=1.0,
187+
z_numrows=self.S,
188+
gen=None,
189+
residual_in_fp32=False,
190+
is_rms_norm=False,
191+
)[0].view(self.B, self.S, self.D)
192+
193+
def verify_large(self) -> torch.Tensor:
194+
return torch.nn.functional.layer_norm(
195+
self.x, [self.D], self.weight, eps=self.eps
196+
)

kernels/src/kernels/cli.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,10 @@ def main():
126126
help="Kernel repo ID (e.g., kernels-community/activation)",
127127
)
128128
benchmark_parser.add_argument(
129-
"--revision",
130-
type=str,
131-
default="main",
132-
help="Kernel revision (default: main)",
129+
"--branch", type=str, help="Kernel branch to benchmark"
130+
)
131+
benchmark_parser.add_argument(
132+
"--version", type=int, help="Kernel version to benchmark"
133133
)
134134
benchmark_parser.add_argument(
135135
"--output",
@@ -247,7 +247,8 @@ def run_benchmark(args):
247247

248248
benchmark.run_benchmark(
249249
repo_id=args.repo_id,
250-
revision=args.revision,
250+
branch=args.branch,
251+
version=args.version,
251252
iterations=args.iterations,
252253
warmup=args.warmup,
253254
output=args.output,

kernels/src/kernels/metadata.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,27 @@
22
from dataclasses import dataclass
33
from pathlib import Path
44

5+
from kernels.compat import tomllib
6+
57

68
@dataclass
79
class Metadata:
810
python_depends: list[str]
911
version: int | None
1012

13+
@staticmethod
14+
def load_from_build_toml(build_toml_path: Path) -> "Metadata":
15+
if build_toml_path.exists():
16+
with open(build_toml_path, "rb") as f:
17+
data = tomllib.load(f)
18+
version = data.get("general", {}).get("version", None)
19+
return Metadata(
20+
version=version,
21+
python_depends=[],
22+
)
23+
24+
return Metadata(version=None, python_depends=[])
25+
1126
@staticmethod
1227
def load_from_variant(variant_path: Path) -> "Metadata":
1328
metadata_path = variant_path / "metadata.json"

kernels/src/kernels/upload.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77

88

99
def upload_kernels_dir(
10-
kernel_dir: Path, *, repo_id: str, branch: str | None, private: bool
10+
kernel_dir: Path,
11+
*,
12+
repo_id: str,
13+
branch: str | None,
14+
private: bool,
1115
):
1216
kernel_dir = Path(kernel_dir).resolve()
1317

@@ -54,6 +58,17 @@ def upload_kernels_dir(
5458
if build_variant.is_dir():
5559
delete_patterns.add(f"{build_variant.name}/**")
5660

61+
# in the case we have variants, upload to the same as the kernel_dir
62+
upload_folder(
63+
repo_id=repo_id,
64+
folder_path=kernel_dir / "benchmarks",
65+
revision=branch,
66+
path_in_repo="benchmarks",
67+
delete_patterns=["benchmark*.py"],
68+
commit_message="Benchmarks uploaded using `kernels`.",
69+
allow_patterns=["benchmark*.py"],
70+
)
71+
5772
upload_folder(
5873
repo_id=repo_id,
5974
folder_path=build_dir,

kernels/uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)