Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions kernels/src/kernels/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ def verify_silu(self) -> torch.Tensor:

seed: int | None = None # Optional: seed for reproducibility

def __init__(self):
self.kernel = None
    def __init__(self) -> None:
        # Kernel module under test; starts as None and must be assigned before
        # any benchmark_* method runs (presumably by the benchmark runner via
        # get_kernel — confirm against run_benchmark_class).
        self.kernel: Any = None
        self.out: Any = None  # Output tensor, set by setup methods

def setup(self):
def setup(self) -> None:
"""Override to set up tensors as instance attributes."""
pass

Expand Down Expand Up @@ -452,16 +452,16 @@ def collect_machine_info() -> MachineInfo:
)


def get_kernel_sha_from_ops(kernel: Any) -> str:
ops_name = kernel.ops.__name__
# Format is torch.ops._<name>_<sha>, extract the last part after underscore
def get_kernel_sha_from_build_name(kernel: Any) -> str:
    """Extract the build SHA from a kernel module's ``__name__``.

    Build names have the form ``<name>_<sha>``; the SHA is everything after
    the final underscore. A name with no underscore is returned unchanged,
    matching the original splitting behavior.
    """
    build_name: str = kernel.__name__
    _, _, sha = build_name.rpartition("_")
    return sha if sha else build_name


def _synchronize() -> None:
if torch.cuda.is_available():
_synchronize()
torch.cuda.synchronize()
elif hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.synchronize()
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
Expand Down Expand Up @@ -491,7 +491,7 @@ def run_benchmark_class(
from kernels import get_kernel

kernel = get_kernel(repo_id, revision=revision)
kernel_sha = get_kernel_sha_from_ops(kernel)
kernel_sha = get_kernel_sha_from_build_name(kernel)

for method_name in benchmark_methods:
workload_name = method_name.replace("benchmark_", "")
Expand Down Expand Up @@ -689,8 +689,8 @@ def submit_benchmark(

def run_benchmark(
repo_id: str,
# TODO: change default to 1 in the future
revision: str = "main",
branch: str | None,
version: int | None,
iterations: int = 100,
warmup: int = 10,
upload: bool = False,
Expand All @@ -708,6 +708,27 @@ def run_benchmark(
# Suppress progress bars for cleaner output (files are often cached)
disable_progress_bars()

# Requires either branch or version or parses from repo_id
if branch is None and version is None:
if "@" not in repo_id:
print("Error: must specify either branch or version", file=sys.stderr)
sys.exit(1)

# Parse from repo_id
repo_id, rev = repo_id.split("@", 1)

if rev.startswith("v") and rev[1:].isdigit():
version = int(rev[1:])
elif rev.isdigit():
print("Error: version must be prefixed with 'v'", file=sys.stderr)
sys.exit(1)
else:
branch = rev

# Move version or branch into revision
revision = f"v{version}" if version is not None else branch
assert revision is not None # Guaranteed by parsing logic above

print(f"Downloading {repo_id}@{revision}...", file=sys.stderr)
repo_path = Path(snapshot_download(repo_id=repo_id, revision=revision))

Expand Down
3 changes: 3 additions & 0 deletions kernels/src/kernels/benchmarks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
FlashAttentionCausalBenchmark,
FlashAttentionVarlenBenchmark,
)
from .layer_norm import LayerNormBenchmark, RMSNormBenchmark

__all__ = [
"FlashAttentionBenchmark",
"FlashAttentionCausalBenchmark",
"FlashAttentionVarlenBenchmark",
"LayerNormBenchmark",
"RMSNormBenchmark",
"SiluAndMulBenchmark",
]
196 changes: 196 additions & 0 deletions kernels/src/kernels/benchmarks/layer_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import torch

from kernels.benchmark import Benchmark


class RMSNormBenchmark(Benchmark):
    """Benchmark fused RMSNorm via the kernel's ``dropout_add_ln_fwd`` op.

    Three workloads (small/medium/large) differ only in their (batch, seq,
    dim) sizes; the shared setup / kernel-invocation / reference logic lives
    in private helpers so each workload is a one-liner.
    """

    seed: int = 42  # reproducibility seed consumed by the harness
    eps: float = 1e-5  # variance epsilon, shared by kernel call and reference

    def _setup(self, B: int, S: int, D: int) -> None:
        """Allocate fp16 CUDA input/weight/output tensors for a (B, S, D) workload."""
        self.x = torch.randn(B, S, D, device="cuda", dtype=torch.float16)
        self.weight = torch.ones(D, device="cuda", dtype=torch.float16)
        self.out = torch.empty_like(self.x)
        self.B, self.S, self.D = B, S, D

    def _run_kernel(self) -> None:
        """Invoke the fused op (dropout disabled, RMS-norm mode) on self.x."""
        self.out = self.kernel.dropout_add_ln_fwd(
            input=self.x.view(-1, self.D),
            gamma=self.weight,
            beta=None,
            rowscale=None,
            colscale=None,
            x0_subset=None,
            z_subset=None,
            dropout_p=0.0,
            epsilon=self.eps,
            rowscale_const=1.0,
            z_numrows=self.S,
            gen=None,
            residual_in_fp32=False,
            is_rms_norm=True,
        )[0].view(self.B, self.S, self.D)

    def _reference(self) -> torch.Tensor:
        """Pure-PyTorch RMSNorm used to verify the kernel output."""
        var = self.x.pow(2).mean(-1, keepdim=True)
        return (self.x * torch.rsqrt(var + self.eps)) * self.weight

    # Workload: small (B=2, S=128, D=768)
    def setup_small(self) -> None:
        self._setup(2, 128, 768)

    def benchmark_small(self) -> None:
        self._run_kernel()

    def verify_small(self) -> torch.Tensor:
        return self._reference()

    # Workload: medium (B=4, S=512, D=2048)
    def setup_medium(self) -> None:
        self._setup(4, 512, 2048)

    def benchmark_medium(self) -> None:
        self._run_kernel()

    def verify_medium(self) -> torch.Tensor:
        return self._reference()

    # Workload: large (B=8, S=1024, D=4096)
    def setup_large(self) -> None:
        self._setup(8, 1024, 4096)

    def benchmark_large(self) -> None:
        self._run_kernel()

    def verify_large(self) -> torch.Tensor:
        return self._reference()


class LayerNormBenchmark(Benchmark):
    """Benchmark fused LayerNorm via the kernel's ``dropout_add_ln_fwd`` op.

    Three workloads (small/medium/large) differ only in their (batch, seq,
    dim) sizes; the shared setup / kernel-invocation / reference logic lives
    in private helpers so each workload is a one-liner.
    """

    seed: int = 42  # reproducibility seed consumed by the harness
    eps: float = 1e-5  # normalization epsilon, shared by kernel call and reference

    def _setup(self, B: int, S: int, D: int) -> None:
        """Allocate fp16 CUDA input/weight/output tensors for a (B, S, D) workload."""
        self.x = torch.randn(B, S, D, device="cuda", dtype=torch.float16)
        self.weight = torch.ones(D, device="cuda", dtype=torch.float16)
        self.out = torch.empty_like(self.x)
        self.B, self.S, self.D = B, S, D

    def _run_kernel(self) -> None:
        """Invoke the fused op (dropout disabled, LayerNorm mode) on self.x."""
        self.out = self.kernel.dropout_add_ln_fwd(
            input=self.x.view(-1, self.D),
            gamma=self.weight,
            beta=None,
            rowscale=None,
            colscale=None,
            x0_subset=None,
            z_subset=None,
            dropout_p=0.0,
            epsilon=self.eps,
            rowscale_const=1.0,
            z_numrows=self.S,
            gen=None,
            residual_in_fp32=False,
            is_rms_norm=False,
        )[0].view(self.B, self.S, self.D)

    def _reference(self) -> torch.Tensor:
        """Pure-PyTorch LayerNorm used to verify the kernel output."""
        return torch.nn.functional.layer_norm(
            self.x, [self.D], self.weight, eps=self.eps
        )

    # Workload: small (B=2, S=128, D=768)
    def setup_small(self) -> None:
        self._setup(2, 128, 768)

    def benchmark_small(self) -> None:
        self._run_kernel()

    def verify_small(self) -> torch.Tensor:
        return self._reference()

    # Workload: medium (B=4, S=512, D=2048)
    def setup_medium(self) -> None:
        self._setup(4, 512, 2048)

    def benchmark_medium(self) -> None:
        self._run_kernel()

    def verify_medium(self) -> torch.Tensor:
        return self._reference()

    # Workload: large (B=8, S=1024, D=4096)
    def setup_large(self) -> None:
        self._setup(8, 1024, 4096)

    def benchmark_large(self) -> None:
        self._run_kernel()

    def verify_large(self) -> torch.Tensor:
        return self._reference()
17 changes: 12 additions & 5 deletions kernels/src/kernels/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def main():
action="store_true",
help="If the repository should be private.",
)
upload_parser.add_argument(
"--benchmarks-only",
action="store_true",
help="If set, only upload the benchmarks directory.",
)
upload_parser.set_defaults(func=upload_kernels)

lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions")
Expand Down Expand Up @@ -126,10 +131,10 @@ def main():
help="Kernel repo ID (e.g., kernels-community/activation)",
)
benchmark_parser.add_argument(
"--revision",
type=str,
default="main",
help="Kernel revision (default: main)",
"--branch", type=str, help="Kernel branch to benchmark"
)
benchmark_parser.add_argument(
"--version", type=int, help="Kernel version to benchmark"
)
benchmark_parser.add_argument(
"--output",
Expand Down Expand Up @@ -211,6 +216,7 @@ def upload_kernels(args):
repo_id=args.repo_id,
branch=args.branch,
private=args.private,
benchmarks_only=args.benchmarks_only,
)


Expand Down Expand Up @@ -247,7 +253,8 @@ def run_benchmark(args):

benchmark.run_benchmark(
repo_id=args.repo_id,
revision=args.revision,
branch=args.branch,
version=args.version,
iterations=args.iterations,
warmup=args.warmup,
output=args.output,
Expand Down
15 changes: 15 additions & 0 deletions kernels/src/kernels/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,27 @@
from dataclasses import dataclass
from pathlib import Path

from kernels.compat import tomllib


@dataclass
class Metadata:
python_depends: list[str]
version: int | None

@staticmethod
def load_from_build_toml(build_toml_path: Path) -> "Metadata":
if build_toml_path.exists():
with open(build_toml_path, "rb") as f:
data = tomllib.load(f)
version = data.get("general", {}).get("version", None)
return Metadata(
version=version,
python_depends=[],
)

return Metadata(version=None, python_depends=[])

@staticmethod
def load_from_variant(variant_path: Path) -> "Metadata":
metadata_path = variant_path / "metadata.json"
Expand Down
Loading
Loading