diff --git a/kernels/src/kernels/benchmark.py b/kernels/src/kernels/benchmark.py index 11d11d7a..609958bd 100644 --- a/kernels/src/kernels/benchmark.py +++ b/kernels/src/kernels/benchmark.py @@ -88,11 +88,11 @@ def verify_silu(self) -> torch.Tensor: seed: int | None = None # Optional: seed for reproducibility - def __init__(self): - self.kernel = None + def __init__(self) -> None: + self.kernel: Any = None self.out: Any = None # Output tensor, set by setup methods - def setup(self): + def setup(self) -> None: """Override to set up tensors as instance attributes.""" pass @@ -452,16 +452,16 @@ def collect_machine_info() -> MachineInfo: ) -def get_kernel_sha_from_ops(kernel: Any) -> str: - ops_name = kernel.ops.__name__ - # Format is torch.ops.__, extract the last part after underscore +def get_kernel_sha_from_build_name(kernel: Any) -> str: + ops_name = kernel.__name__ + # Format is _, extract the last part after underscore sha = ops_name.rsplit("_", 1)[-1] return sha def _synchronize() -> None: if torch.cuda.is_available(): - _synchronize() + torch.cuda.synchronize() elif hasattr(torch, "xpu") and torch.xpu.is_available(): torch.xpu.synchronize() elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): @@ -491,7 +491,7 @@ def run_benchmark_class( from kernels import get_kernel kernel = get_kernel(repo_id, revision=revision) - kernel_sha = get_kernel_sha_from_ops(kernel) + kernel_sha = get_kernel_sha_from_build_name(kernel) for method_name in benchmark_methods: workload_name = method_name.replace("benchmark_", "") @@ -689,8 +689,8 @@ def submit_benchmark( def run_benchmark( repo_id: str, - # TODO: change default to 1 in the future - revision: str = "main", + branch: str | None, + version: int | None, iterations: int = 100, warmup: int = 10, upload: bool = False, @@ -708,6 +708,27 @@ def run_benchmark( # Suppress progress bars for cleaner output (files are often cached) disable_progress_bars() + # Requires either branch or version or parses from repo_id + if branch is None and version is None: + if "@" not in repo_id: + print("Error: must specify either branch or version", file=sys.stderr) + sys.exit(1) + + # Parse from repo_id + repo_id, rev = repo_id.split("@", 1) + + if rev.startswith("v") and rev[1:].isdigit(): + version = int(rev[1:]) + elif rev.isdigit(): + print("Error: version must be prefixed with 'v'", file=sys.stderr) + sys.exit(1) + else: + branch = rev + + # Move version or branch into revision + revision = f"v{version}" if version is not None else branch + assert revision is not None # Guaranteed by parsing logic above + print(f"Downloading {repo_id}@{revision}...", file=sys.stderr) repo_path = Path(snapshot_download(repo_id=repo_id, revision=revision)) diff --git a/kernels/src/kernels/benchmarks/__init__.py b/kernels/src/kernels/benchmarks/__init__.py index 528cf85e..fdf89634 100644 --- a/kernels/src/kernels/benchmarks/__init__.py +++ b/kernels/src/kernels/benchmarks/__init__.py @@ -4,10 +4,13 @@ FlashAttentionCausalBenchmark, FlashAttentionVarlenBenchmark, ) +from .layer_norm import LayerNormBenchmark, RMSNormBenchmark __all__ = [ "FlashAttentionBenchmark", "FlashAttentionCausalBenchmark", "FlashAttentionVarlenBenchmark", + "LayerNormBenchmark", + "RMSNormBenchmark", "SiluAndMulBenchmark", ] diff --git a/kernels/src/kernels/benchmarks/layer_norm.py b/kernels/src/kernels/benchmarks/layer_norm.py new file mode 100644 index 00000000..7c7b319c --- /dev/null +++ b/kernels/src/kernels/benchmarks/layer_norm.py @@ -0,0 +1,196 @@ +import torch + +from kernels.benchmark import Benchmark + + +class RMSNormBenchmark(Benchmark): + seed: int = 42 + eps: float = 1e-5 + + # Workload: small (B=2, S=128, D=768) + def setup_small(self): + B, S, D = 2, 128, 768 + self.x = torch.randn(B, S, D, device="cuda", dtype=torch.float16) + self.weight = torch.ones(D, device="cuda", dtype=torch.float16) + self.out = torch.empty_like(self.x) + self.B, self.S, self.D = B, S, D + + def benchmark_small(self): + self.out = self.kernel.dropout_add_ln_fwd( + input=self.x.view(-1, self.D), + gamma=self.weight, + beta=None, + rowscale=None, + colscale=None, + x0_subset=None, + z_subset=None, + dropout_p=0.0, + epsilon=self.eps, + rowscale_const=1.0, + z_numrows=self.S, + gen=None, + residual_in_fp32=False, + is_rms_norm=True, + )[0].view(self.B, self.S, self.D) + + def verify_small(self) -> torch.Tensor: + var = self.x.pow(2).mean(-1, keepdim=True) + return (self.x * torch.rsqrt(var + self.eps)) * self.weight + + # Workload: medium (B=4, S=512, D=2048) + def setup_medium(self): + B, S, D = 4, 512, 2048 + self.x = torch.randn(B, S, D, device="cuda", dtype=torch.float16) + self.weight = torch.ones(D, device="cuda", dtype=torch.float16) + self.out = torch.empty_like(self.x) + self.B, self.S, self.D = B, S, D + + def benchmark_medium(self): + self.out = self.kernel.dropout_add_ln_fwd( + input=self.x.view(-1, self.D), + gamma=self.weight, + beta=None, + rowscale=None, + colscale=None, + x0_subset=None, + z_subset=None, + dropout_p=0.0, + epsilon=self.eps, + rowscale_const=1.0, + z_numrows=self.S, + gen=None, + residual_in_fp32=False, + is_rms_norm=True, + )[0].view(self.B, self.S, self.D) + + def verify_medium(self) -> torch.Tensor: + var = self.x.pow(2).mean(-1, keepdim=True) + return (self.x * torch.rsqrt(var + self.eps)) * self.weight + + # Workload: large (B=8, S=1024, D=4096) + def setup_large(self): + B, S, D = 8, 1024, 4096 + self.x = torch.randn(B, S, D, device="cuda", dtype=torch.float16) + self.weight = torch.ones(D, device="cuda", dtype=torch.float16) + self.out = torch.empty_like(self.x) + self.B, self.S, self.D = B, S, D + + def benchmark_large(self): + self.out = self.kernel.dropout_add_ln_fwd( + input=self.x.view(-1, self.D), + gamma=self.weight, + beta=None, + rowscale=None, + colscale=None, + x0_subset=None, + z_subset=None, + dropout_p=0.0, + epsilon=self.eps, + rowscale_const=1.0, + z_numrows=self.S, + gen=None, + residual_in_fp32=False, + is_rms_norm=True, + )[0].view(self.B, self.S, self.D) + + def verify_large(self) -> torch.Tensor: + var = self.x.pow(2).mean(-1, keepdim=True) + return (self.x * torch.rsqrt(var + self.eps)) * self.weight + + +class LayerNormBenchmark(Benchmark): + seed: int = 42 + eps: float = 1e-5 + + # Workload: small (B=2, S=128, D=768) + def setup_small(self): + B, S, D = 2, 128, 768 + self.x = torch.randn(B, S, D, device="cuda", dtype=torch.float16) + self.weight = torch.ones(D, device="cuda", dtype=torch.float16) + self.out = torch.empty_like(self.x) + self.B, self.S, self.D = B, S, D + + def benchmark_small(self): + self.out = self.kernel.dropout_add_ln_fwd( + input=self.x.view(-1, self.D), + gamma=self.weight, + beta=None, + rowscale=None, + colscale=None, + x0_subset=None, + z_subset=None, + dropout_p=0.0, + epsilon=self.eps, + rowscale_const=1.0, + z_numrows=self.S, + gen=None, + residual_in_fp32=False, + is_rms_norm=False, + )[0].view(self.B, self.S, self.D) + + def verify_small(self) -> torch.Tensor: + return torch.nn.functional.layer_norm( + self.x, [self.D], self.weight, eps=self.eps + ) + + # Workload: medium (B=4, S=512, D=2048) + def setup_medium(self): + B, S, D = 4, 512, 2048 + self.x = torch.randn(B, S, D, device="cuda", dtype=torch.float16) + self.weight = torch.ones(D, device="cuda", dtype=torch.float16) + self.out = torch.empty_like(self.x) + self.B, self.S, self.D = B, S, D + + def benchmark_medium(self): + self.out = self.kernel.dropout_add_ln_fwd( + input=self.x.view(-1, self.D), + gamma=self.weight, + beta=None, + rowscale=None, + colscale=None, + x0_subset=None, + z_subset=None, + dropout_p=0.0, + epsilon=self.eps, + rowscale_const=1.0, + z_numrows=self.S, + gen=None, + residual_in_fp32=False, + is_rms_norm=False, + )[0].view(self.B, self.S, self.D) + + def verify_medium(self) -> torch.Tensor: + return torch.nn.functional.layer_norm( + self.x, [self.D], self.weight, eps=self.eps + ) + + # Workload: large (B=8, S=1024, D=4096) + def setup_large(self): + B, S, D = 8, 1024, 4096 + self.x = torch.randn(B, S, D, device="cuda", dtype=torch.float16) + self.weight = torch.ones(D, device="cuda", dtype=torch.float16) + self.out = torch.empty_like(self.x) + self.B, self.S, self.D = B, S, D + + def benchmark_large(self): + self.out = self.kernel.dropout_add_ln_fwd( + input=self.x.view(-1, self.D), + gamma=self.weight, + beta=None, + rowscale=None, + colscale=None, + x0_subset=None, + z_subset=None, + dropout_p=0.0, + epsilon=self.eps, + rowscale_const=1.0, + z_numrows=self.S, + gen=None, + residual_in_fp32=False, + is_rms_norm=False, + )[0].view(self.B, self.S, self.D) + + def verify_large(self) -> torch.Tensor: + return torch.nn.functional.layer_norm( + self.x, [self.D], self.weight, eps=self.eps + ) diff --git a/kernels/src/kernels/cli.py b/kernels/src/kernels/cli.py index dd128b48..9a08a8d8 100644 --- a/kernels/src/kernels/cli.py +++ b/kernels/src/kernels/cli.py @@ -84,6 +84,11 @@ def main(): action="store_true", help="If the repository should be private.", ) + upload_parser.add_argument( + "--benchmarks-only", + action="store_true", + help="If set, only upload the benchmarks directory.", + ) upload_parser.set_defaults(func=upload_kernels) lock_parser = subparsers.add_parser("lock", help="Lock kernel revisions") @@ -126,10 +131,10 @@ def main(): help="Kernel repo ID (e.g., kernels-community/activation)", ) benchmark_parser.add_argument( - "--revision", - type=str, - default="main", - help="Kernel revision (default: main)", + "--branch", type=str, help="Kernel branch to benchmark" + ) + benchmark_parser.add_argument( + "--version", type=int, help="Kernel version to benchmark" ) benchmark_parser.add_argument( "--output", @@ -211,6 +216,7 @@ def upload_kernels(args): repo_id=args.repo_id, branch=args.branch, private=args.private, + benchmarks_only=args.benchmarks_only, ) @@ -247,7 +253,8 @@ def run_benchmark(args): benchmark.run_benchmark( repo_id=args.repo_id, - revision=args.revision, + branch=args.branch, + version=args.version, iterations=args.iterations, warmup=args.warmup, output=args.output, diff --git a/kernels/src/kernels/metadata.py b/kernels/src/kernels/metadata.py index ccea1964..1c50abc4 100644 --- a/kernels/src/kernels/metadata.py +++ b/kernels/src/kernels/metadata.py @@ -2,12 +2,27 @@ from dataclasses import dataclass from pathlib import Path +from kernels.compat import tomllib + @dataclass class Metadata: python_depends: list[str] version: int | None + @staticmethod + def load_from_build_toml(build_toml_path: Path) -> "Metadata": + if build_toml_path.exists(): + with open(build_toml_path, "rb") as f: + data = tomllib.load(f) + version = data.get("general", {}).get("version", None) + return Metadata( + version=version, + python_depends=[], + ) + + return Metadata(version=None, python_depends=[]) + @staticmethod def load_from_variant(variant_path: Path) -> "Metadata": metadata_path = variant_path / "metadata.json" diff --git a/kernels/src/kernels/upload.py b/kernels/src/kernels/upload.py index 77738fff..a3aa2ea2 100644 --- a/kernels/src/kernels/upload.py +++ b/kernels/src/kernels/upload.py @@ -7,10 +7,46 @@ def upload_kernels_dir( - kernel_dir: Path, *, repo_id: str, branch: str | None, private: bool + kernel_dir: Path, + *, + repo_id: str, + branch: str | None, + private: bool, + benchmarks_only: bool, ): kernel_dir = Path(kernel_dir).resolve() + repo_id = create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id + + if branch is not None: + create_branch(repo_id=repo_id, branch=branch, exist_ok=True) + + # if uploading benchmarks read the version from build.toml + # as to not depend on build variants + if benchmarks_only: + metadata = Metadata.load_from_build_toml(kernel_dir / "build.toml") + if metadata.version is None: + raise ValueError( + f"Cannot upload benchmarks only without a version specified in build.toml at: {(kernel_dir / 'build.toml').absolute()}" + ) + + branch = f"v{metadata.version}" + + upload_folder( + repo_id=repo_id, + folder_path=kernel_dir / "benchmarks", + revision=branch, + path_in_repo="benchmarks", + delete_patterns=["benchmark*.py"], + commit_message="Benchmarks uploaded using `kernels`.", + allow_patterns=["benchmark*.py"], + ) + + print( + f"✅ Benchmarks upload successful. Find the kernel in: https://hf.co/{repo_id}" + ) + return # Exit if only benchmarks are to be uploaded + build_dir = None variants = None for candidate in [kernel_dir / "build", kernel_dir]: @@ -43,17 +79,24 @@ def upload_kernels_dir( version = versions.pop() if version is not None: branch = f"v{version}" - - repo_id = create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id - - if branch is not None: - create_branch(repo_id=repo_id, branch=branch, exist_ok=True) + create_branch(repo_id=repo_id, branch=branch, exist_ok=True) delete_patterns: set[str] = set() for build_variant in build_dir.iterdir(): if build_variant.is_dir(): delete_patterns.add(f"{build_variant.name}/**") + # in the case we have variants, upload to the same as the kernel_dir + upload_folder( + repo_id=repo_id, + folder_path=kernel_dir / "benchmarks", + revision=branch, + path_in_repo="benchmarks", + delete_patterns=["benchmark*.py"], + commit_message="Benchmarks uploaded using `kernels`.", + allow_patterns=["benchmark*.py"], + ) + upload_folder( repo_id=repo_id, folder_path=build_dir, @@ -63,4 +106,6 @@ def upload_kernels_dir( commit_message="Build uploaded using `kernels`.", allow_patterns=["torch*"], ) - print(f"✅ Kernel upload successful. Find the kernel in: https://hf.co/{repo_id}") + print( + f"✅ Kernel and benchmarks upload successful. Find the kernel in: https://hf.co/{repo_id}" + ) diff --git a/kernels/uv.lock b/kernels/uv.lock index 43c75809..5ade9666 100644 --- a/kernels/uv.lock +++ b/kernels/uv.lock @@ -670,7 +670,7 @@ wheels = [ [[package]] name = "kernels" -version = "0.11.6.dev0" +version = "0.11.8.dev0" source = { editable = "." } dependencies = [ { name = "huggingface-hub" },