Skip to content

Commit 7858dca

Browse files
committed
feat: upload benchmarks folder and subcommand cleanups
1 parent 7912260 commit 7858dca

4 files changed

Lines changed: 240 additions & 12 deletions

File tree

kernels/src/kernels/benchmark.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -452,16 +452,16 @@ def collect_machine_info() -> MachineInfo:
452452
)
453453

454454

455-
def get_kernel_sha_from_ops(kernel: Any) -> str:
456-
ops_name = kernel.ops.__name__
457-
# Format is torch.ops._<name>_<sha>, extract the last part after underscore
455+
def get_kernel_sha_from_build_name(kernel: Any) -> str:
    """Extract the trailing SHA from a kernel's build name.

    Build names have the form ``<name>_<sha>``; the text after the
    final underscore is taken as the SHA.
    """
    build_name = kernel.__name__
    # rpartition yields ("", "", build_name) when no underscore exists,
    # so index [2] matches rsplit("_", 1)[-1] in every case.
    return build_name.rpartition("_")[2]
460460

461461

462462
def _synchronize() -> None:
463463
if torch.cuda.is_available():
464-
_synchronize()
464+
torch.cuda.synchronize()
465465
elif hasattr(torch, "xpu") and torch.xpu.is_available():
466466
torch.xpu.synchronize()
467467
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
@@ -491,7 +491,7 @@ def run_benchmark_class(
491491
from kernels import get_kernel
492492

493493
kernel = get_kernel(repo_id, revision=revision)
494-
kernel_sha = get_kernel_sha_from_ops(kernel)
494+
kernel_sha = get_kernel_sha_from_build_name(kernel)
495495

496496
for method_name in benchmark_methods:
497497
workload_name = method_name.replace("benchmark_", "")
@@ -659,6 +659,8 @@ def run_benchmark_script(
659659
)
660660
for name, timing in results.items():
661661
all_results[f"{cls.__name__}.{name}"] = timing
662+
663+
print("test")
662664
return all_results, kernel_sha
663665

664666

kernels/src/kernels/benchmarks/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
FlashAttentionCausalBenchmark,
55
FlashAttentionVarlenBenchmark,
66
)
7+
from .layer_norm import LayerNormBenchmark, RMSNormBenchmark
78

89
__all__ = [
910
"FlashAttentionBenchmark",
1011
"FlashAttentionCausalBenchmark",
1112
"FlashAttentionVarlenBenchmark",
13+
"LayerNormBenchmark",
14+
"RMSNormBenchmark",
1215
"SiluAndMulBenchmark",
1316
]
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import torch
2+
3+
from kernels.benchmark import Benchmark
4+
5+
6+
class RMSNormBenchmark(Benchmark):
    """Benchmark RMSNorm via the kernel's fused ``dropout_add_ln_fwd`` op.

    Three workloads (small/medium/large) share identical logic and differ
    only in the (batch, sequence, hidden) shape, so the public per-workload
    entry points delegate to private helpers.
    """

    seed: int = 42
    eps: float = 1e-5

    def _setup(self, B: int, S: int, D: int) -> None:
        """Allocate fp16 CUDA input, an all-ones scale, and an output buffer."""
        self.x = torch.randn(B, S, D, device="cuda", dtype=torch.float16)
        self.weight = torch.ones(D, device="cuda", dtype=torch.float16)
        self.out = torch.empty_like(self.x)
        self.B, self.S, self.D = B, S, D

    def _run(self) -> None:
        """Invoke the fused kernel once, writing the result to ``self.out``.

        The op takes a 2-D input, hence the view/reshape round-trip; dropout
        is disabled and ``is_rms_norm=True`` selects the RMSNorm path.
        """
        self.out = self.kernel.dropout_add_ln_fwd(
            input=self.x.view(-1, self.D),
            gamma=self.weight,
            beta=None,
            rowscale=None,
            colscale=None,
            x0_subset=None,
            z_subset=None,
            dropout_p=0.0,
            epsilon=self.eps,
            rowscale_const=1.0,
            z_numrows=self.S,
            gen=None,
            residual_in_fp32=False,
            is_rms_norm=True,
        )[0].view(self.B, self.S, self.D)

    def _reference(self) -> torch.Tensor:
        """Pure-PyTorch RMSNorm used as ground truth for verification."""
        var = self.x.pow(2).mean(-1, keepdim=True)
        return (self.x * torch.rsqrt(var + self.eps)) * self.weight

    # Workload: small (B=2, S=128, D=768)
    def setup_small(self):
        self._setup(2, 128, 768)

    def benchmark_small(self):
        self._run()

    def verify_small(self) -> torch.Tensor:
        return self._reference()

    # Workload: medium (B=4, S=512, D=2048)
    def setup_medium(self):
        self._setup(4, 512, 2048)

    def benchmark_medium(self):
        self._run()

    def verify_medium(self) -> torch.Tensor:
        return self._reference()

    # Workload: large (B=8, S=1024, D=4096)
    def setup_large(self):
        self._setup(8, 1024, 4096)

    def benchmark_large(self):
        self._run()

    def verify_large(self) -> torch.Tensor:
        return self._reference()
99+
100+
101+
class LayerNormBenchmark(Benchmark):
    """Benchmark LayerNorm via the kernel's fused ``dropout_add_ln_fwd`` op.

    Three workloads (small/medium/large) share identical logic and differ
    only in the (batch, sequence, hidden) shape, so the public per-workload
    entry points delegate to private helpers.
    """

    seed: int = 42
    eps: float = 1e-5

    def _setup(self, B: int, S: int, D: int) -> None:
        """Allocate fp16 CUDA input, an all-ones scale, and an output buffer."""
        self.x = torch.randn(B, S, D, device="cuda", dtype=torch.float16)
        self.weight = torch.ones(D, device="cuda", dtype=torch.float16)
        self.out = torch.empty_like(self.x)
        self.B, self.S, self.D = B, S, D

    def _run(self) -> None:
        """Invoke the fused kernel once, writing the result to ``self.out``.

        The op takes a 2-D input, hence the view/reshape round-trip; dropout
        is disabled and ``is_rms_norm=False`` selects the LayerNorm path.
        """
        self.out = self.kernel.dropout_add_ln_fwd(
            input=self.x.view(-1, self.D),
            gamma=self.weight,
            beta=None,
            rowscale=None,
            colscale=None,
            x0_subset=None,
            z_subset=None,
            dropout_p=0.0,
            epsilon=self.eps,
            rowscale_const=1.0,
            z_numrows=self.S,
            gen=None,
            residual_in_fp32=False,
            is_rms_norm=False,
        )[0].view(self.B, self.S, self.D)

    def _reference(self) -> torch.Tensor:
        """Pure-PyTorch LayerNorm (no bias) used as ground truth."""
        return torch.nn.functional.layer_norm(
            self.x, [self.D], self.weight, eps=self.eps
        )

    # Workload: small (B=2, S=128, D=768)
    def setup_small(self):
        self._setup(2, 128, 768)

    def benchmark_small(self):
        self._run()

    def verify_small(self) -> torch.Tensor:
        return self._reference()

    # Workload: medium (B=4, S=512, D=2048)
    def setup_medium(self):
        self._setup(4, 512, 2048)

    def benchmark_medium(self):
        self._run()

    def verify_medium(self) -> torch.Tensor:
        return self._reference()

    # Workload: large (B=8, S=1024, D=4096)
    def setup_large(self):
        self._setup(8, 1024, 4096)

    def benchmark_large(self):
        self._run()

    def verify_large(self) -> torch.Tensor:
        return self._reference()

kernels/src/kernels/cli.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,18 @@ def main():
8181
action="store_true",
8282
help="If the repository should be private.",
8383
)
84+
# By default, benchmarks are not included; enable them with --benchmarks,
85+
# or upload only the benchmarks directory with --benchmarks-only.
86+
upload_parser.add_argument(
87+
"--benchmarks",
88+
action="store_true",
89+
help="If set, upload both benchmarks and build variants (default).",
90+
)
91+
upload_parser.add_argument(
92+
"--benchmarks-only",
93+
action="store_true",
94+
help="If set, only upload the benchmarks directory.",
95+
)
8496
upload_parser.set_defaults(func=upload_kernels)
8597

8698
lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions")
@@ -202,6 +214,28 @@ def upload_kernels(args):
202214
# Resolve `kernel_dir` to be uploaded.
203215
kernel_dir = Path(args.kernel_dir).resolve()
204216

217+
repo_id = create_repo(
218+
repo_id=args.repo_id, private=args.private, exist_ok=True
219+
).repo_id
220+
221+
if args.branch is not None:
222+
create_branch(repo_id=repo_id, branch=args.branch, exist_ok=True)
223+
224+
# benchmarks directory upload (doesn't require build variants)
225+
if args.benchmarks or args.benchmarks_only:
226+
upload_folder(
227+
repo_id=repo_id,
228+
folder_path=kernel_dir / "benchmarks",
229+
revision=args.branch,
230+
path_in_repo="benchmarks",
231+
delete_patterns=["benchmark*.py"],
232+
commit_message="Benchmarks uploaded using `kernels`.",
233+
allow_patterns=["benchmark*.py"],
234+
)
235+
236+
if args.benchmarks_only:
237+
return # Exit if only benchmarks are to be uploaded
238+
205239
build_dir = None
206240
for candidate in [kernel_dir / "build", kernel_dir]:
207241
variants = [
@@ -217,13 +251,6 @@ def upload_kernels(args):
217251
f"Couldn't find any build variants in: {kernel_dir.absolute()} or {(kernel_dir / 'build').absolute()}"
218252
)
219253

220-
repo_id = create_repo(
221-
repo_id=args.repo_id, private=args.private, exist_ok=True
222-
).repo_id
223-
224-
if args.branch is not None:
225-
create_branch(repo_id=repo_id, branch=args.branch, exist_ok=True)
226-
227254
delete_patterns: set[str] = set()
228255
for build_variant in build_dir.iterdir():
229256
if build_variant.is_dir():

0 commit comments

Comments
 (0)