diff --git a/add.py b/add.py deleted file mode 100644 index 9b34f89..0000000 --- a/add.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch -import triton - -import ops.ninetoothed.torch -import ops.triton.torch - -if __name__ == "__main__": - torch.manual_seed(0) - - size = 98432 - dtype = torch.float16 - device = "cuda" - - input = torch.randn(size, dtype=dtype, device=device) - other = torch.randn(size, dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.add(input, other) - torch_output = input + other - triton_output = ops.triton.torch.add(input, other) - - print(ninetoothed_output) - print(torch_output) - print(triton_output) - - if torch.allclose(ninetoothed_output, torch_output): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["size"], - x_vals=[2**i for i in range(18, 28)], - x_log=True, - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="ms", - plot_name="add-performance", - args={}, - ) - ) - def benchmark(size, provider): - input = torch.randn(size, dtype=dtype, device=device) - other = torch.randn(size, dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.add(input, other) - torch_output = torch.add(input, other) - triton_output = ops.triton.torch.add(input, other) - - assert torch.allclose(ninetoothed_output, torch_output) - assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) - - if provider == "ninetoothed": - ms = triton.testing.do_bench( - lambda: ops.ninetoothed.torch.add(input, other) - ) - elif provider == "torch": - ms = triton.testing.do_bench(lambda: torch.add(input, other)) - elif provider == "triton": - ms = triton.testing.do_bench(lambda: ops.triton.torch.add(input, other)) - - return ms - - benchmark.run(print_data=True, show_plots=True, save_path=".") diff --git a/addmm.py b/addmm.py deleted file mode 100644 index 59b21ad..0000000 --- a/addmm.py +++ /dev/null @@ -1,81 +0,0 @@ -import random - -import torch -import triton - -import ops.ninetoothed.torch -import ops.triton.torch - -if __name__ == "__main__": - random.seed(0) - torch.manual_seed(0) - - shape = (512, 512) - dtype = torch.float16 - device = "cuda" - - input = torch.randn(shape, dtype=dtype, device=device) - mat1 = torch.randn(shape, dtype=dtype, device=device) - mat2 = torch.randn(shape, dtype=dtype, device=device) - beta = random.uniform(0, 1) - alpha = random.uniform(0, 1) - - ninetoothed_output = ops.ninetoothed.torch.addmm( - input, mat1, mat2, beta=beta, alpha=alpha - ) - torch_output = torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha) - triton_output = ops.triton.torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha) - - print(ninetoothed_output) - print(torch_output) - print(triton_output) - - if torch.allclose(ninetoothed_output, torch_output, atol=0.01, rtol=0.01): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["m", "n", "k"], - x_vals=[128 * i for i in range(2, 33)], - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="ms", - plot_name="addmm-performance", - args={}, - ) - ) - def benchmark(m, n, k, provider): - input = torch.randn((m, n), dtype=dtype, device=device) - mat1 = torch.randn((m, k), dtype=dtype, device=device) - mat2 = torch.randn((k, n), dtype=dtype, device=device) - beta = random.uniform(0, 1) - alpha = random.uniform(0, 1) - - if provider == "ninetoothed": - ms = triton.testing.do_bench( - lambda: ops.ninetoothed.torch.addmm( - input, mat1, mat2, beta=beta, alpha=alpha - ) - ) - elif provider == "torch": - ms = triton.testing.do_bench( - lambda: torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha) - ) - elif provider == "triton": - ms = triton.testing.do_bench( - lambda: ops.triton.torch.addmm( - input, mat1, mat2, beta=beta, alpha=alpha - ) - ) - - return ms - - benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/bench.py b/bench.py new file mode 100644 index 0000000..1d60d12 --- /dev/null +++ b/bench.py @@ -0,0 +1,116 @@ +import torch +import triton + +DISPLAY_NAMES = { + "ninetoothed": "NineToothed", + "torch": "PyTorch", + "triton": "Triton", +} + +STYLES = [ + ("blue", "-"), + ("green", "-"), + ("orange", "-"), + ("red", "-"), + ("purple", "-"), + ("cyan", "-"), +] + + +def assert_match(impls, args, kwargs=None, tolerances=None): + """Assert that all implementations produce matching outputs. + + Same API as ``check``, but raises ``AssertionError`` on mismatch + instead of printing. Intended for use in test suites. + + :param impls: Ordered dict mapping provider name to callable. + :param args: Tuple of positional arguments. + :param kwargs: Dict of keyword arguments. + :param tolerances: Dict mapping provider name to ``torch.allclose`` kwargs. + """ + kwargs = kwargs or {} + tolerances = tolerances or {} + results = {name: fn(*args, **kwargs) for name, fn in impls.items()} + + names = list(impls) + reference_name = names[0] + reference = results[reference_name] + + for name in names[1:]: + tol = tolerances.get(name, {}) + ref_display = _display_name(reference_name) + other_display = _display_name(name) + + assert torch.allclose(reference, results[name], **tol), ( + f"{ref_display} and {other_display} outputs differ." + ) + + +def benchmark( + impls, + make_inputs, + x_names, + x_vals, + name, + benchmark_args=None, + x_log=True, + assert_correctness=True, + tolerances=None, + save_path=".", +): + """Create and run a performance benchmark. + + :param impls: Ordered dict mapping provider name to callable. + :param make_inputs: Callable returning ``(args_tuple, kwargs_dict)``. + :param x_names: List of benchmark parameter names. + :param x_vals: List of benchmark parameter values. + :param name: Operator name, used for the plot filename. + :param benchmark_args: Fixed benchmark args dict. + :param x_log: Whether to use log scale for the x-axis. + :param tolerances: Dict mapping provider name to ``torch.allclose`` kwargs. + :param assert_correctness: Whether to assert correctness at each point. + :param save_path: Directory to save plot files, or ``None`` to skip saving. + """ + providers = list(impls) + tolerances = tolerances or {} + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=x_names, + x_vals=x_vals, + line_arg="provider", + line_vals=providers, + line_names=[_display_name(p) for p in providers], + plot_name=f"{name}-performance", + args=benchmark_args or {}, + ylabel="ms", + x_log=x_log, + styles=[_style(i) for i in range(len(providers))], + ) + ) + def bench(provider, **params): + args, kwargs = make_inputs(**params) + + if assert_correctness: + results = {p: impls[p](*args, **kwargs) for p in providers} + reference = results[providers[0]] + + for p in providers[1:]: + tol = tolerances.get(p, {}) + assert torch.allclose(reference, results[p], **tol) + + return triton.testing.do_bench(lambda: impls[provider](*args, **kwargs)) + + bench.run(print_data=True, save_path=save_path) + + +def _display_name(name): + """Return the display name for a provider.""" + + return DISPLAY_NAMES.get(name, name) + + +def _style(index): + """Return a plot style, cycling through available options.""" + + return STYLES[index % len(STYLES)] diff --git a/bmm.py b/bmm.py deleted file mode 100644 index 3243511..0000000 --- a/bmm.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch -import triton - -import ops.ninetoothed.torch -import ops.triton.torch - -if __name__ == "__main__": - torch.manual_seed(0) - - batch_size, m, n, k = 4, 512, 2028, 1024 - dtype = torch.float16 - device = "cuda" - - input = torch.randn(batch_size, m, k, dtype=dtype, device=device) - other = torch.randn(batch_size, k, n, dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.bmm(input, other) - torch_output = torch.bmm(input, other) - triton_output = ops.triton.torch.bmm(input, other) - - print(ninetoothed_output) - print(torch_output) - print(triton_output) - - if torch.allclose(ninetoothed_output, torch_output): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["m", "n", "k"], - x_vals=[2**i for i in range(3, 13)], - x_log=True, - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="ms", - plot_name="bmm-performance", - args={"b": 4}, - ) - ) - def benchmark(b, m, n, k, provider): - input = torch.randn((b, m, k), dtype=dtype, device=device) - other = torch.randn((b, k, n), dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.bmm(input, other) - torch_output = torch.bmm(input, other) - triton_output = ops.triton.torch.bmm(input, other) - - assert torch.allclose(ninetoothed_output, torch_output) - assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) - - if provider == "ninetoothed": - ms = triton.testing.do_bench( - lambda: ops.ninetoothed.torch.bmm(input, other) - ) - elif provider == "torch": - ms = triton.testing.do_bench(lambda: torch.bmm(input, other)) - elif provider == "triton": - ms = triton.testing.do_bench(lambda: ops.triton.torch.bmm(input, other)) - - return ms - - benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/conv2d.py b/conv2d.py deleted file mode 100644 index c7bd8bb..0000000 --- a/conv2d.py +++ /dev/null @@ -1,75 +0,0 @@ -import torch -import torch.nn.functional as F -import triton - -import ops.ninetoothed.torch -import ops.triton.torch - -if __name__ == "__main__": - torch.manual_seed(0) - - n, c, h, w = 4, 3, 224, 224 - k, _, r, s = 8, c, 3, 3 - dtype = torch.float16 - device = "cuda" - - input = torch.randn(n, c, h, w, dtype=dtype, device=device) - filter = torch.randn(k, c, r, s, dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.conv2d(input, filter) - torch_output = F.conv2d(input, filter) - triton_output = ops.triton.torch.conv2d(input, filter) - - print(ninetoothed_output) - print(torch_output) - print(triton_output) - - if torch.allclose(ninetoothed_output, torch_output, atol=0.01, rtol=0.01): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["n"], - x_vals=[2**i for i in range(1, 11)], - x_log=True, - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="ms", - plot_name="conv2d-performance", - args={}, - ) - ) - def benchmark(n, provider): - _, c, h, w = n, 512, 14, 14 - k, _, r, s = 512, c, 3, 3 - - input = torch.randn((n, c, h, w), dtype=dtype, device=device) - filter = torch.randn((k, c, r, s), dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.conv2d(input, filter) - torch_output = F.conv2d(input, filter) - triton_output = ops.triton.torch.conv2d(input, filter) - - assert torch.allclose(ninetoothed_output, torch_output, atol=0.01, rtol=0.01) - assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) - - if provider == "ninetoothed": - ms = triton.testing.do_bench( - lambda: ops.ninetoothed.torch.conv2d(input, filter) - ) - elif provider == "torch": - ms = triton.testing.do_bench(lambda: F.conv2d(input, filter)) - elif provider == "triton": - ms = triton.testing.do_bench(lambda: ops.triton.torch.conv2d(input, filter)) - - return ms - - benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/fused_rms_norm.py b/fused_rms_norm.py deleted file mode 100644 index a32ce9b..0000000 --- a/fused_rms_norm.py +++ /dev/null @@ -1,114 +0,0 @@ -from contextlib import contextmanager - -import torch -import torch.nn as nn -import torch.nn.functional as F -import triton - -import ops.ninetoothed.torch -import ops.triton.torch - - -class RMSNorm(nn.Module): - fused_rms_norm = None - - def __init__(self, other): - super().__init__() - - self.__dict__ = other.__dict__ - - def forward(self, x): - return type(self).fused_rms_norm(x, self.weight, self.variance_epsilon) - - -@contextmanager -def rms_norm_backend(backend_name): - def _torch_fused_rms_norm(x, w, eps): - return F.rms_norm(x, x.shape[-1:], w, eps) - - _prev_impl = RMSNorm.fused_rms_norm - - if backend_name == "ninetoothed": - impl = ops.ninetoothed.torch.fused_rms_norm - elif backend_name == "triton": - impl = ops.triton.torch.fused_rms_norm - elif backend_name == "torch": - impl = _torch_fused_rms_norm - else: - raise ValueError(f"unknown backend: `{backend_name}`") - - RMSNorm.fused_rms_norm = impl - - try: - yield - finally: - RMSNorm.fused_rms_norm = _prev_impl - - -if __name__ == "__main__": - torch.manual_seed(0) - - dtype = torch.float16 - device = "cuda" - - x = torch.randn(1151, 8192, dtype=dtype, device=device) - w = torch.randn(8192, dtype=dtype, device=device) - eps = 1e-5 - - ninetoothed_output = ops.ninetoothed.torch.fused_rms_norm(x, w, eps) - torch_output = F.rms_norm(x, x.shape[-1:], w, eps) - triton_output = ops.triton.torch.fused_rms_norm(x, w, eps) - - print(ninetoothed_output) - print(torch_output) - print(triton_output) - - if torch.allclose(ninetoothed_output, torch_output, atol=0.001, rtol=0.005): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.005): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["n"], - x_vals=[2**i for i in range(5, 15)], - x_log=True, - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="ms", - plot_name="fused-rms-norm-performance", - args={"m": 4096}, - ) - ) - def benchmark(m, n, provider): - x = torch.randn(m, n, dtype=dtype, device=device) - w = torch.randn(n, dtype=dtype, device=device) - eps = 1e-5 - - ninetoothed_output = ops.ninetoothed.torch.fused_rms_norm(x, w, eps) - torch_output = F.rms_norm(x, x.shape[-1:], w, eps) - triton_output = ops.triton.torch.fused_rms_norm(x, w, eps) - - assert torch.allclose(ninetoothed_output, torch_output, atol=0.001, rtol=0.005) - assert torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.005) - - if provider == "ninetoothed": - ms = triton.testing.do_bench( - lambda: ops.ninetoothed.torch.fused_rms_norm(x, w, eps) - ) - elif provider == "torch": - ms = triton.testing.do_bench(lambda: F.rms_norm(x, x.shape[-1:], w, eps)) - elif provider == "triton": - ms = triton.testing.do_bench( - lambda: ops.triton.torch.fused_rms_norm(x, w, eps) - ) - - return ms - - benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/infer.py b/infer.py index c5b2ab3..86f6ff5 100644 --- a/infer.py +++ b/infer.py @@ -5,15 +5,18 @@ import torch from transformers import AutoModelForCausalLM, AutoTokenizer -from fused_rms_norm import RMSNorm, rms_norm_backend -from linear import Linear, bmm_backend -from scaled_dot_product_attention import ( +from modules import ( Attention, + Linear, + RMSNorm, + SiLU, + bmm_backend, + replace_module, + rms_norm_backend, rotary_position_embedding_backend, scaled_dot_product_attention_backend, + silu_backend, ) -from silu import SiLU, silu_backend -from utils import replace_module if __name__ == "__main__": parser = argparse.ArgumentParser( diff --git a/linear.py b/linear.py deleted file mode 100644 index 50bf179..0000000 --- a/linear.py +++ /dev/null @@ -1,41 +0,0 @@ -from contextlib import contextmanager - -import torch -import torch.nn as nn - -import ops.ninetoothed.torch - - -class Linear(nn.Module): - bmm = None - - def __init__(self, other): - super().__init__() - - self.__dict__ = other.__dict__ - - def forward(self, input): - return type(self).bmm( - input, self.weight.T.unsqueeze(0).expand(input.shape[0], -1, -1) - ) - - -@contextmanager -def bmm_backend(backend_name): - _prev_impl = Linear.bmm - - if backend_name == "ninetoothed": - impl = ops.ninetoothed.torch.bmm - elif backend_name == "triton": - impl = ops.triton.torch.bmm - elif backend_name == "torch": - impl = torch.bmm - else: - raise ValueError(f"unknown backend: `{backend_name}`") - - Linear.bmm = impl - - try: - yield - finally: - Linear.bmm = _prev_impl diff --git a/max_pool2d.py b/max_pool2d.py deleted file mode 100644 index 584ee53..0000000 --- a/max_pool2d.py +++ /dev/null @@ -1,100 +0,0 @@ -import math - -import ninetoothed -import ninetoothed.language as ntl -import torch -import torch.nn.functional as F -import triton -from ninetoothed import Symbol, Tensor - - -def arrangement(input, output): - BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True) - - WINDOW_HEIGHT = Symbol("WINDOW_HEIGHT", constexpr=True, upper_bound=16) - WINDOW_WIDTH = Symbol("WINDOW_WIDTH", constexpr=True, upper_bound=16) - - input_arranged = input.tile((1, 1, WINDOW_HEIGHT, WINDOW_WIDTH)) - input_arranged = input_arranged.ravel() - input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1) - input_arranged = input_arranged.tile((BLOCK_SIZE, -1)) - - output_arranged = output.tile((1, 1, 1, 1)) - output_arranged = output_arranged.ravel() - output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1) - output_arranged = output_arranged.tile((BLOCK_SIZE, -1)) - output_arranged.dtype = output_arranged.dtype.squeeze(1) - - return input_arranged, output_arranged - - -def application(input, output): - output = ntl.max(input, axis=1) # noqa: F841 - - -max_pool2d_kernel = ninetoothed.make( - arrangement, application, (Tensor(4, other=float("-inf")), Tensor(4)) -) - - -def max_pool2d(input, window_shape): - n, c, h, w = input.shape - r, s = window_shape - p = math.ceil((h - r) / r + 1) - q = math.ceil((w - s) / s + 1) - - output = torch.empty(n, c, p, q, dtype=input.dtype, device=input.device) - - max_pool2d_kernel(input, output, WINDOW_HEIGHT=r, WINDOW_WIDTH=s) - - return output - - -if __name__ == "__main__": - torch.manual_seed(0) - - input_shape = (32, 3, 64, 64) - window_shape = (3, 3) - - input = torch.randn(input_shape, dtype=torch.float16, device="cuda") - - ninetoothed_output = max_pool2d(input, window_shape) - torch_output = F.max_pool2d(input, window_shape, ceil_mode=True) - - print(ninetoothed_output) - print(torch_output) - - if torch.allclose(ninetoothed_output, torch_output): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["h", "w"], - x_vals=[8 * i for i in range(2, 33)], - line_arg="provider", - line_vals=["ninetoothed", "torch"], - line_names=["NineToothed", "PyTorch"], - styles=[("blue", "-"), ("green", "-")], - ylabel="ms", - plot_name="max-pool2d-performance", - args={}, - ) - ) - def benchmark(h, w, provider): - n, c, h, w = 64, 64, h, w - r, s = 3, 3 - dtype = torch.float16 - device = "cuda" - input = torch.randn((n, c, h, w), dtype=dtype, device=device) - window_shape = (r, s) - - if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: max_pool2d(input, window_shape)) - elif provider == "torch": - ms = triton.testing.do_bench(lambda: F.max_pool2d(input, window_shape)) - - return ms - - benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/mm.py b/mm.py deleted file mode 100644 index f4d0b95..0000000 --- a/mm.py +++ /dev/null @@ -1,68 +0,0 @@ -import torch -import triton - -import ops.ninetoothed.torch -import ops.triton.torch - -if __name__ == "__main__": - torch.manual_seed(0) - - shape = (512, 512) - dtype = torch.float16 - device = "cuda" - - input = torch.randn(shape, dtype=dtype, device=device) - other = torch.randn(shape, dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.mm(input, other) - torch_output = torch.mm(input, other) - triton_output = ops.triton.torch.mm(input, other) - - print(ninetoothed_output) - print(torch_output) - print(triton_output) - - if torch.allclose(ninetoothed_output, torch_output): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["m", "n", "k"], - x_vals=[2**i for i in range(3, 13)], - x_log=True, - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="ms", - plot_name="mm-performance", - args={}, - ) - ) - def benchmark(m, n, k, provider): - input = torch.randn((m, k), dtype=dtype, device=device) - other = torch.randn((k, n), dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.mm(input, other) - torch_output = torch.mm(input, other) - triton_output = ops.triton.torch.mm(input, other) - - assert torch.allclose(ninetoothed_output, torch_output, atol=0.025, rtol=0.025) - assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) - - if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.mm(input, other)) - elif provider == "torch": - ms = triton.testing.do_bench(lambda: torch.mm(input, other)) - elif provider == "triton": - ms = triton.testing.do_bench(lambda: ops.triton.torch.mm(input, other)) - - return ms - - benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000..543259b --- /dev/null +++ b/modules/__init__.py @@ -0,0 +1,26 @@ +from modules._utils import replace_module +from modules.attention import ( + Attention, + generate_sin_and_cos_tables, + rotary_position_embedding_backend, + scaled_dot_product_attention_backend, + torch_rotary_position_embedding, +) +from modules.linear import Linear, bmm_backend +from modules.rms_norm import RMSNorm, rms_norm_backend +from modules.silu import SiLU, silu_backend + +__all__ = [ + "Attention", + "Linear", + "RMSNorm", + "SiLU", + "bmm_backend", + "generate_sin_and_cos_tables", + "replace_module", + "rms_norm_backend", + "rotary_position_embedding_backend", + "scaled_dot_product_attention_backend", + "silu_backend", + "torch_rotary_position_embedding", +] diff --git a/modules/_utils.py b/modules/_utils.py new file mode 100644 index 0000000..65fa21f --- /dev/null +++ b/modules/_utils.py @@ -0,0 +1,45 @@ +from contextlib import contextmanager + + +def replace_module(module, replacement_class): + """Recursively replace modules whose class name contains the replacement class name.""" + + for child_name, child_module in module.named_children(): + if replacement_class.__name__ not in child_module.__class__.__name__: + replace_module(child_module, replacement_class) + + continue + + replacement = replacement_class(child_module) + setattr(module, child_name, replacement) + + +def _make_backend_manager(cls, attr, impls): + """Create a context manager that switches a class attribute to a different backend. + + :param cls: The module class whose attribute will be swapped. + :param attr: The name of the class attribute to swap. + :param impls: Dict mapping backend name to callable implementation. + :return: A context manager function. + """ + + @contextmanager + def backend(backend_name): + prev = getattr(cls, attr) + setattr(cls, attr, _get_impl(backend_name, impls)) + + try: + yield + finally: + setattr(cls, attr, prev) + + return backend + + +def _get_impl(backend_name, impls): + """Return the implementation for the given backend name.""" + + if backend_name not in impls: + raise ValueError(f"Unknown backend: `{backend_name}`.") + + return impls[backend_name] diff --git a/modules/attention.py b/modules/attention.py new file mode 100644 index 0000000..dca50c8 --- /dev/null +++ b/modules/attention.py @@ -0,0 +1,132 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.models.llama.modeling_llama import repeat_kv + +import ops.ninetoothed.torch +import ops.triton.torch +from modules._utils import _make_backend_manager + + +def torch_rotary_position_embedding(input, sin_table, cos_table, interleaved=True): + batch_size, seq_len, num_heads, emb_dim = input.shape + + assert emb_dim % 2 == 0, "The embedding dimension must be even." + + sin_table = sin_table[None, :, None, :] + cos_table = cos_table[None, :, None, :] + + if interleaved: + pair_wise_input = input.view(batch_size, seq_len, num_heads, emb_dim // 2, 2) + input_0, input_1 = pair_wise_input[..., 0], pair_wise_input[..., 1] + input_0_rotated = input_0 * cos_table - input_1 * sin_table + input_1_rotated = input_0 * sin_table + input_1 * cos_table + + return torch.stack((input_0_rotated, input_1_rotated), dim=-1).view(input.shape) + else: + input_0 = input[..., : input.shape[-1] // 2] + input_1 = input[..., input.shape[-1] // 2 :] + input_0_rotated = input_0 * cos_table - input_1 * sin_table + input_1_rotated = input_0 * sin_table + input_1 * cos_table + + return torch.cat((input_0_rotated, input_1_rotated), dim=-1) + + +def generate_sin_and_cos_tables( + seq_len, emb_dim, base=10000, dtype=torch.float32, device="cuda" +): + assert emb_dim % 2 == 0, "The embedding dimension must be even." + + theta = base ** ( + -2 * (torch.arange(emb_dim // 2, dtype=dtype, device=device) / emb_dim) + ) + + positions = torch.arange(seq_len, dtype=dtype, device=device).unsqueeze(1) + sin_table = torch.sin(positions * theta) + cos_table = torch.cos(positions * theta) + + return sin_table, cos_table + + +class Attention(nn.Module): + scaled_dot_product_attention = None + rotary_position_embedding = None + + def __init__(self, other): + super().__init__() + self.__dict__ = other.__dict__ + + def forward( + self, + hidden_states, + position_embeddings, + attention_mask, + past_key_value, + cache_position, + **kwargs, + ): + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape) + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) + + cos_table, sin_table = position_embeddings + sin_table = sin_table[0, ..., sin_table.shape[-1] // 2 :] + cos_table = cos_table[0, ..., cos_table.shape[-1] // 2 :] + + query_states = type(self).rotary_position_embedding( + query_states, sin_table, cos_table, interleaved=False + ) + key_states = type(self).rotary_position_embedding( + key_states, sin_table, cos_table, interleaved=False + ) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin_table, + "cos": cos_table, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_output = type(self).scaled_dot_product_attention( + query_states, key_states, value_states, scale=self.scaling + ) + attn_output = attn_output.transpose(1, 2) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, None + + +scaled_dot_product_attention_backend = _make_backend_manager( + Attention, + "scaled_dot_product_attention", + { + "ninetoothed": ops.ninetoothed.torch.scaled_dot_product_attention, + "triton": ops.triton.torch.scaled_dot_product_attention, + "torch": F.scaled_dot_product_attention, + }, +) + +rotary_position_embedding_backend = _make_backend_manager( + Attention, + "rotary_position_embedding", + { + "ninetoothed": ops.ninetoothed.torch.rotary_position_embedding, + "triton": ops.triton.torch.rotary_position_embedding, + "torch": torch_rotary_position_embedding, + }, +) diff --git a/modules/linear.py b/modules/linear.py new file mode 100644 index 0000000..8ab2d3a --- /dev/null +++ b/modules/linear.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn + +import ops.ninetoothed.torch +import ops.triton.torch +from modules._utils import _make_backend_manager + + +class Linear(nn.Module): + bmm = None + + def __init__(self, other): + super().__init__() + self.__dict__ = other.__dict__ + + def forward(self, input): + return type(self).bmm( + input, self.weight.T.unsqueeze(0).expand(input.shape[0], -1, -1) + ) + + +bmm_backend = _make_backend_manager( + Linear, + "bmm", + { + "ninetoothed": ops.ninetoothed.torch.bmm, + "triton": ops.triton.torch.bmm, + "torch": torch.bmm, + }, +) diff --git a/modules/rms_norm.py b/modules/rms_norm.py new file mode 100644 index 0000000..e6ad4c0 --- /dev/null +++ b/modules/rms_norm.py @@ -0,0 +1,32 @@ +import torch.nn as nn +import torch.nn.functional as F + +import ops.ninetoothed.torch +import ops.triton.torch +from modules._utils import _make_backend_manager + + +def _torch_fused_rms_norm(x, w, eps): + return F.rms_norm(x, x.shape[-1:], w, eps) + + +class RMSNorm(nn.Module): + fused_rms_norm = None + + def __init__(self, other): + super().__init__() + self.__dict__ = other.__dict__ + + def forward(self, x): + return type(self).fused_rms_norm(x, self.weight, self.variance_epsilon) + + +rms_norm_backend = _make_backend_manager( + RMSNorm, + "fused_rms_norm", + { + "ninetoothed": ops.ninetoothed.torch.fused_rms_norm, + "triton": ops.triton.torch.fused_rms_norm, + "torch": _torch_fused_rms_norm, + }, +) diff --git a/modules/silu.py b/modules/silu.py new file mode 100644 index 0000000..7bdb25b --- /dev/null +++ b/modules/silu.py @@ -0,0 +1,28 @@ +import torch.nn as nn +import torch.nn.functional as F + +import ops.ninetoothed.torch +import ops.triton.torch +from modules._utils import _make_backend_manager + + +class SiLU(nn.Module): + silu = None + + def __init__(self, other): + super().__init__() + self.__dict__ = other.__dict__ + + def forward(self, input): + return type(self).silu(input) + + +silu_backend = _make_backend_manager( + SiLU, + "silu", + { + "ninetoothed": ops.ninetoothed.torch.silu, + "triton": ops.triton.torch.silu, + "torch": F.silu, + }, +) diff --git a/ops/ninetoothed/kernels/max_pool2d.py b/ops/ninetoothed/kernels/max_pool2d.py new file mode 100644 index 0000000..4a8e5ce --- /dev/null +++ b/ops/ninetoothed/kernels/max_pool2d.py @@ -0,0 +1,32 @@ +import ninetoothed +import ninetoothed.language as ntl +from ninetoothed import Symbol, Tensor + +BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True) + +WINDOW_HEIGHT = Symbol("WINDOW_HEIGHT", constexpr=True, upper_bound=16) +WINDOW_WIDTH = Symbol("WINDOW_WIDTH", constexpr=True, upper_bound=16) + + +def arrangement(input, output): + input_arranged = input.tile((1, 1, WINDOW_HEIGHT, WINDOW_WIDTH)) + input_arranged = input_arranged.ravel() + input_arranged = input_arranged.flatten(end_dim=4).flatten(start_dim=1) + input_arranged = input_arranged.tile((BLOCK_SIZE, -1)) + + output_arranged = output.tile((1, 1, 1, 1)) + output_arranged = output_arranged.ravel() + output_arranged = output_arranged.flatten(end_dim=4).flatten(start_dim=1) + output_arranged = output_arranged.tile((BLOCK_SIZE, -1)) + output_arranged.dtype = output_arranged.dtype.squeeze(1) + + return input_arranged, output_arranged + + +def application(input, output): + output = ntl.max(input, axis=1) # noqa: F841 + + +kernel = ninetoothed.make( + arrangement, application, (Tensor(4, other=float("-inf")), Tensor(4)) +) diff --git a/ops/ninetoothed/torch.py b/ops/ninetoothed/torch.py index fe0824d..d04ffa3 100644 --- a/ops/ninetoothed/torch.py +++ b/ops/ninetoothed/torch.py @@ -7,6 +7,7 @@ import ops.ninetoothed.kernels.bmm import ops.ninetoothed.kernels.conv2d import ops.ninetoothed.kernels.fused_rms_norm +import ops.ninetoothed.kernels.max_pool2d import ops.ninetoothed.kernels.mm import ops.ninetoothed.kernels.rms_norm import ops.ninetoothed.kernels.rotary_position_embedding @@ -70,6 +71,21 @@ def fused_rms_norm(x, w, eps=None): return y_2d.view(x.shape) +def max_pool2d(input, window_shape): + n, c, h, w = input.shape + r, s = window_shape + p = math.ceil((h - r) / r + 1) + q = math.ceil((w - s) / s + 1) + + output = torch.empty(n, c, p, q, dtype=input.dtype, device=input.device) + + ops.ninetoothed.kernels.max_pool2d.kernel( + input, output, WINDOW_HEIGHT=r, WINDOW_WIDTH=s + ) + + return output + + def mm(input, other): output_shape = (input.shape[0], other.shape[1]) output = torch.empty(output_shape, dtype=input.dtype, device=input.device) diff --git a/requirements.txt b/requirements.txt index 8844f71..026f93e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ matplotlib pandas transformers radon +pytest diff --git a/rms_norm.py b/rms_norm.py deleted file mode 100644 index c06ee7f..0000000 --- a/rms_norm.py +++ /dev/null @@ -1,68 +0,0 @@ -import torch -import torch.nn.functional as F -import triton - -import ops.ninetoothed.torch -import ops.triton.torch - -if __name__ == "__main__": - torch.manual_seed(0) - - dtype = torch.float16 - device = "cuda" - - input = torch.randn(1151, 8192, dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.rms_norm(input) - torch_output = F.rms_norm(input, input.shape[-1:]) - triton_output = ops.triton.torch.rms_norm(input) - - print(ninetoothed_output) - print(torch_output) - print(triton_output) - - if torch.allclose(ninetoothed_output, torch_output, atol=0.001, rtol=0.005): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["n"], - x_vals=[2**i for i in range(5, 15)], - x_log=True, - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="ms", - plot_name="rms-norm-performance", - args={"m": 4096}, - ) - ) - def benchmark(m, n, provider): - input = torch.randn(m, n, dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.rms_norm(input) - torch_output = F.rms_norm(input, input.shape[-1:]) - triton_output = ops.triton.torch.rms_norm(input) - - assert torch.allclose(ninetoothed_output, torch_output, atol=0.001, rtol=0.005) - assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) - - if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.rms_norm(input)) - elif provider == "torch": - ms = triton.testing.do_bench( - lambda: torch.rms_norm(input, input.shape[-1:]) - ) - elif provider == "triton": - ms = triton.testing.do_bench(lambda: ops.triton.torch.rms_norm(input)) - - return ms - - benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/rotary_position_embedding.py b/rotary_position_embedding.py deleted file mode 100644 index 0ce57e4..0000000 --- a/rotary_position_embedding.py +++ /dev/null @@ -1,122 +0,0 @@ -import torch -import triton - -import ops.ninetoothed.torch -import ops.triton.torch - - -def torch_rotary_position_embedding(input, sin_table, cos_table, interleaved=True): - batch_size, seq_len, num_heads, emb_dim = input.shape - - assert emb_dim % 2 == 0, "The embedding dimension must be even." - - sin_table = sin_table[None, :, None, :] - cos_table = cos_table[None, :, None, :] - - if interleaved: - pair_wise_input = input.view(batch_size, seq_len, num_heads, emb_dim // 2, 2) - input_0, input_1 = pair_wise_input[..., 0], pair_wise_input[..., 1] - input_0_rotated = input_0 * cos_table - input_1 * sin_table - input_1_rotated = input_0 * sin_table + input_1 * cos_table - - return torch.stack((input_0_rotated, input_1_rotated), dim=-1).view(input.shape) - else: - input_0 = input[..., : input.shape[-1] // 2] - input_1 = input[..., input.shape[-1] // 2 :] - input_0_rotated = input_0 * cos_table - input_1 * sin_table - input_1_rotated = input_0 * sin_table + input_1 * cos_table - - return torch.cat((input_0_rotated, input_1_rotated), dim=-1) - - -def _generate_sin_and_cos_tables( - seq_len, emb_dim, base=10000, dtype=torch.float32, device="cuda" -): - assert emb_dim % 2 == 0, "The embedding dimension must be even." - - theta = base ** ( - -2 * (torch.arange(emb_dim // 2, dtype=dtype, device=device) / emb_dim) - ) - - positions = torch.arange(seq_len, dtype=dtype, device=device).unsqueeze(1) - sin_table = torch.sin(positions * theta) - cos_table = torch.cos(positions * theta) - - return sin_table, cos_table - - -if __name__ == "__main__": - torch.manual_seed(0) - - batch_size, seq_len, num_heads, emb_dim = 4, 128, 8, 64 - dtype = torch.float32 - device = "cuda" - - sin_table, cos_table = _generate_sin_and_cos_tables(seq_len, emb_dim) - x = torch.randn(batch_size, seq_len, num_heads, emb_dim, dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.rotary_position_embedding( - x, sin_table, cos_table, interleaved=False - ) - torch_output = torch_rotary_position_embedding( - x, sin_table, cos_table, interleaved=False - ) - triton_output = ops.triton.torch.rotary_position_embedding( - x, sin_table, cos_table, interleaved=False - ) - - print(ninetoothed_output) - print(torch_output) - print(triton_output) - - if torch.allclose(ninetoothed_output, torch_output, atol=0.001): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["seq_len"], - x_vals=[2**i for i in range(5, 15)], - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="ms", - plot_name="rotary-position-embedding-performance", - args={}, - ) - ) - def benchmark(seq_len, provider): - batch_size, num_heads, emb_dim = 4, 32, 64 - shape = (batch_size, seq_len, num_heads, emb_dim) - dtype = torch.float16 - device = "cuda" - - sin_table, cos_table = _generate_sin_and_cos_tables(seq_len, emb_dim) - x = torch.randn(shape, dtype=dtype, device=device) - - if provider == "ninetoothed": - ms = triton.testing.do_bench( - lambda: ops.ninetoothed.torch.rotary_position_embedding( - x, sin_table, cos_table - ) - ) - elif provider == "torch": - ms = triton.testing.do_bench( - lambda: torch_rotary_position_embedding(x, sin_table, cos_table) - ) - elif provider == "triton": - ms = triton.testing.do_bench( - lambda: ops.triton.torch.rotary_position_embedding( - x, sin_table, cos_table - ) - ) - - return ms - - benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/run_experiments.py b/run_experiments.py index 2c5181e..2e99231 100644 --- a/run_experiments.py +++ b/run_experiments.py @@ -10,7 +10,7 @@ import ops.ninetoothed.torch import ops.triton.torch -import rotary_position_embedding +from modules import torch_rotary_position_embedding PROMPTS = ( "The emergence of deep learning domain-specific languages (DSLs) has substantially reduced the obstacles in developing high-performance, cross-platform compute kernels, but current DSLs", @@ -31,7 +31,7 @@ def _run_task(op_name, dtype, device, *arg_shapes, **kwarg_shapes): triton_op = getattr(ops.triton.torch, op_name) if op_name == "rotary_position_embedding": - torch_op = rotary_position_embedding.torch_rotary_position_embedding + torch_op = torch_rotary_position_embedding else: torch_op = ( getattr(torch, op_name) diff --git a/scaled_dot_product_attention.py b/scaled_dot_product_attention.py deleted file mode 100644 index ad44b77..0000000 --- a/scaled_dot_product_attention.py +++ /dev/null @@ -1,196 +0,0 @@ -from contextlib import contextmanager - -import torch -import torch.nn as nn -import torch.nn.functional as F -import triton -from transformers.models.llama.modeling_llama import repeat_kv - -import ops.ninetoothed.torch -import ops.triton.torch -from rotary_position_embedding import torch_rotary_position_embedding - - -class Attention(nn.Module): - scaled_dot_product_attention = None - - rotary_position_embedding = None - - def __init__(self, other): - super().__init__() - - self.__dict__ = other.__dict__ - - def forward( - self, - hidden_states, - position_embeddings, - attention_mask, - past_key_value, - cache_position, - **kwargs, - ): - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) - - query_states = self.q_proj(hidden_states).view(hidden_shape) - key_states = self.k_proj(hidden_states).view(hidden_shape) - value_states = self.v_proj(hidden_states).view(hidden_shape) - - cos_table, sin_table = position_embeddings - sin_table = sin_table[0, ..., sin_table.shape[-1] // 2 :] - cos_table = cos_table[0, ..., cos_table.shape[-1] // 2 :] - - query_states = type(self).rotary_position_embedding( - query_states, sin_table, cos_table, interleaved=False - ) - key_states = type(self).rotary_position_embedding( - key_states, sin_table, cos_table, interleaved=False - ) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - if past_key_value is not None: - cache_kwargs = { - "sin": sin_table, - "cos": cos_table, - "cache_position": cache_position, - } - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_output = type(self).scaled_dot_product_attention( - query_states, key_states, value_states, scale=self.scaling - ) - attn_output = attn_output.transpose(1, 2) - - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - - return attn_output, None - - -@contextmanager -def scaled_dot_product_attention_backend(backend_name): - _prev_impl = Attention.scaled_dot_product_attention - - if backend_name == "ninetoothed": - impl = ops.ninetoothed.torch.scaled_dot_product_attention - elif backend_name == "triton": - impl = ops.triton.torch.scaled_dot_product_attention - elif backend_name == "torch": - impl = F.scaled_dot_product_attention - else: - raise ValueError(f"unknown backend: `{backend_name}`") - - Attention.scaled_dot_product_attention = impl - - try: - yield - finally: - Attention.scaled_dot_product_attention = _prev_impl - - -@contextmanager -def rotary_position_embedding_backend(backend_name): - _prev_impl = Attention.rotary_position_embedding - - if backend_name == "ninetoothed": - impl = ops.ninetoothed.torch.rotary_position_embedding - elif backend_name == "triton": - impl = ops.triton.torch.rotary_position_embedding - elif backend_name == "torch": - impl = torch_rotary_position_embedding - else: - raise ValueError(f"unknown backend: `{backend_name}`") - - Attention.rotary_position_embedding = impl - - try: - yield - finally: - Attention.rotary_position_embedding = _prev_impl - - -if __name__ == "__main__": - torch.manual_seed(0) - - q_o_shape = (2, 8, 1024, 64) - k_v_shape = (2, 8, 1024, 64) - dtype = torch.float16 - device = "cuda" - - q = torch.randn(q_o_shape, dtype=dtype, device=device) - k = torch.randn(k_v_shape, dtype=dtype, device=device) - v = torch.randn(k_v_shape, dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.scaled_dot_product_attention(q, k, v) - torch_output = F.scaled_dot_product_attention(q, k, v) - triton_output = ops.triton.torch.scaled_dot_product_attention(q, k, v) - - print(ninetoothed_output) - print(torch_output) - print(triton_output) - - if torch.allclose(ninetoothed_output, torch_output, atol=0.01): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=1e-3, rtol=0): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["seq_len"], - x_vals=[2**i for i in range(7, 17)], - x_log=True, - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="ms", - plot_name="scaled-dot-product-attention-performance", - args={}, - ) - ) - def benchmark(seq_len, provider): - batch_size, num_heads, emb_dim = 4, 32, 64 - shape = (batch_size, num_heads, seq_len, emb_dim) - dtype = torch.float16 - device = "cuda" - - q = torch.randn(shape, dtype=dtype, device=device) - k = torch.randn(shape, dtype=dtype, device=device) - v = torch.randn(shape, dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.scaled_dot_product_attention(q, k, v) - torch_output = F.scaled_dot_product_attention(q, k, v) - triton_output = ops.triton.torch.scaled_dot_product_attention(q, k, v) - - assert torch.allclose(ninetoothed_output, torch_output, atol=0.025, rtol=0.025) - assert torch.allclose(ninetoothed_output, triton_output, atol=0.001, rtol=0.001) - - if provider == "ninetoothed": - ms = triton.testing.do_bench( - lambda: ops.ninetoothed.torch.scaled_dot_product_attention(q, k, v) - ) - elif provider == "torch": - ms = triton.testing.do_bench( - lambda: F.scaled_dot_product_attention(q, k, v) - ) - elif provider == "triton": - ms = triton.testing.do_bench( - lambda: ops.triton.torch.scaled_dot_product_attention(q, k, v) - ) - - return ms - - benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/silu.py b/silu.py deleted file mode 100644 index fafae5d..0000000 --- a/silu.py +++ /dev/null @@ -1,104 +0,0 @@ -from contextlib import contextmanager - -import torch -import torch.nn as nn -import torch.nn.functional as F -import triton - -import ops.ninetoothed.torch -import ops.triton.torch - - -class SiLU(nn.Module): - silu = None - - def __init__(self, other): - super().__init__() - - self.__dict__ = other.__dict__ - - def forward(self, input): - return type(self).silu(input) - - -@contextmanager -def silu_backend(backend_name): - _prev_impl = SiLU.silu - - if backend_name == "ninetoothed": - impl = ops.ninetoothed.torch.silu - elif backend_name == "triton": - impl = ops.triton.torch.silu - elif backend_name == "torch": - impl = F.silu - else: - raise ValueError(f"unknown backend: `{backend_name}`") - - SiLU.silu = impl - - try: - yield - finally: - SiLU.silu = _prev_impl - - -if __name__ == "__main__": - torch.manual_seed(0) - - shape = (8, 256, 512) - dtype = torch.float16 - device = "cuda" - - input = torch.randn(shape, dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.silu(input) - torch_output = F.silu(input) - triton_output = ops.triton.torch.silu(input) - - print(ninetoothed_output) - print(torch_output) - print(triton_output) - - if torch.allclose(ninetoothed_output, torch_output, atol=1e-3, rtol=1e-3): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["m", "n", "k"], - x_vals=[2**i for i in range(3, 10)], - x_log=True, - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="ms", - plot_name="silu-performance", - args={}, - ) - ) - def benchmark(m, n, k, provider): - input = torch.randn(m, n, k, dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.silu(input) - torch_output = F.silu(input) - triton_output = ops.triton.torch.silu(input) - - assert torch.allclose(ninetoothed_output, torch_output, atol=0.001) - assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) - - if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.silu(input)) - elif provider == "torch": - ms = triton.testing.do_bench(lambda: F.silu(input)) - elif provider == "triton": - ms = triton.testing.do_bench(lambda: ops.triton.torch.silu(input)) - - return ms - - benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/softmax.py b/softmax.py deleted file mode 100644 index f979928..0000000 --- a/softmax.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch -import triton - -import ops.ninetoothed.torch -import ops.triton.torch - -if __name__ == "__main__": - torch.manual_seed(0) - - dtype = torch.float16 - device = "cuda" - - input = torch.randn(1823, 781, dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.softmax(input) - torch_output = torch.softmax(input, axis=-1) - triton_output = ops.triton.torch.softmax(input) - - print(ninetoothed_output) - print(torch_output) - print(triton_output) - - if torch.allclose(ninetoothed_output, torch_output, atol=0.001): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["n"], - x_vals=[2**i for i in range(5, 15)], - x_log=True, - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="ms", - plot_name="softmax-performance", - args={"m": 4096}, - ) - ) - def benchmark(m, n, provider): - input = torch.randn(m, n, dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.softmax(input) - torch_output = torch.softmax(input, axis=-1) - triton_output = ops.triton.torch.softmax(input) - - assert torch.allclose(ninetoothed_output, torch_output, atol=0.001) - assert torch.allclose(ninetoothed_output, triton_output, atol=0, rtol=0) - - if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.softmax(input)) - elif provider == "torch": - ms = triton.testing.do_bench(lambda: torch.softmax(input, axis=-1)) - elif provider == "triton": - ms = triton.testing.do_bench(lambda: ops.triton.torch.softmax(input)) - - return ms - - benchmark.run(show_plots=True, print_data=True, save_path=".") diff --git a/swiglu.py b/swiglu.py deleted file mode 100644 index edf70c1..0000000 --- a/swiglu.py +++ /dev/null @@ -1,73 +0,0 @@ -import torch -import torch.nn.functional as F -import triton - -import ops.ninetoothed.torch -import ops.triton.torch - - -def torch_swiglu( - a: torch.Tensor, - b: torch.Tensor, -) -> torch.Tensor: - return a * F.silu(b) - - -if __name__ == "__main__": - torch.manual_seed(0) - - shape = (13, 3) - dtype = torch.float16 - device = "cuda" - - a = torch.randn(shape, dtype=dtype, device=device) - b = torch.randn(shape, dtype=dtype, device=device) - c = torch.randn(shape, dtype=dtype, device=device) - - ninetoothed_output = ops.ninetoothed.torch.swiglu(a, b) - torch_output = torch_swiglu(a, b) - triton_output = ops.triton.torch.swiglu(a, b) - - print(ninetoothed_output) - print(torch_output) - print(triton_output) - - if torch.allclose(ninetoothed_output, torch_output, atol=0, rtol=1e-3): - print("✅ NineToothed and PyTorch match.") - else: - print("❌ NineToothed and PyTorch differ.") - if torch.allclose(ninetoothed_output, triton_output): - print("✅ NineToothed and Triton match.") - else: - print("❌ NineToothed and Triton differ.") - - @triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["m", "n"], - x_vals=[128 * i for i in range(2, 50)], - x_log=True, - line_arg="provider", - line_vals=["ninetoothed", "torch", "triton"], - line_names=["NineToothed", "PyTorch", "Triton"], - styles=[("blue", "-"), ("green", "-"), ("orange", "-")], - ylabel="ms", - plot_name="swiglu-performance", - args={}, - ) - ) - def benchmark(m, n, provider): - shape = (m, n) - - a = torch.randn(shape, dtype=dtype, device=device) - b = torch.randn(shape, dtype=dtype, device=device) - - if provider == "ninetoothed": - ms = triton.testing.do_bench(lambda: ops.ninetoothed.torch.swiglu(a, b)) - elif provider == "torch": - ms = triton.testing.do_bench(lambda: torch_swiglu(a, b)) - elif provider == "triton": - ms = triton.testing.do_bench(lambda: ops.triton.torch.swiglu(a, b)) - - return ms - - benchmark.run(print_data=True, show_plots=True, save_path=".") diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..902b1e1 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,27 @@ +import pytest +import torch + + +def pytest_configure(config): + """Register custom markers.""" + config.addinivalue_line( + "markers", "benchmark: performance benchmarks (deselected by default)" + ) + + +def pytest_collection_modifyitems(config, items): + """Skip benchmark tests unless explicitly selected via ``-m benchmark``.""" + if config.getoption("-m") == "benchmark": + return + + skip = pytest.mark.skip(reason="benchmarks are not selected, use `-m benchmark`") + + for item in items: + if "benchmark" in item.keywords: + item.add_marker(skip) + + +@pytest.fixture(autouse=True) +def seed(): + """Set a fixed random seed before each test for reproducibility.""" + torch.manual_seed(0) diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py new file mode 100644 index 0000000..fda79e4 --- /dev/null +++ b/tests/test_benchmarks.py @@ -0,0 +1,369 @@ +import random + +import pytest +import torch +import torch.nn.functional as F + +import bench +import ops.ninetoothed.torch +import ops.triton.torch +from modules import generate_sin_and_cos_tables, torch_rotary_position_embedding + +pytestmark = [ + pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available"), + pytest.mark.benchmark, +] + +DTYPE = torch.float16 +DEVICE = "cuda" + + +class TestAddBenchmark: + def test_benchmark(self): + impls = { + "ninetoothed": ops.ninetoothed.torch.add, + "torch": torch.add, + "triton": ops.triton.torch.add, + } + + def make_inputs(size): + return ( + torch.randn(size, dtype=DTYPE, device=DEVICE), + torch.randn(size, dtype=DTYPE, device=DEVICE), + ), {} + + bench.benchmark( + impls, + make_inputs, + x_names=["size"], + x_vals=[2**i for i in range(18, 28)], + tolerances={"triton": {"atol": 0, "rtol": 0}}, + name="add", + ) + + +class TestMMBenchmark: + def test_benchmark(self): + impls = { + "ninetoothed": ops.ninetoothed.torch.mm, + "torch": torch.mm, + "triton": ops.triton.torch.mm, + } + + def make_inputs(m, n, k): + return ( + torch.randn((m, k), dtype=DTYPE, device=DEVICE), + torch.randn((k, n), dtype=DTYPE, device=DEVICE), + ), {} + + bench.benchmark( + impls, + make_inputs, + x_names=["m", "n", "k"], + x_vals=[2**i for i in range(3, 13)], + tolerances={ + "torch": {"atol": 0.025, "rtol": 0.025}, + "triton": {"atol": 0, "rtol": 0}, + }, + name="mm", + ) + + +class TestBMMBenchmark: + def test_benchmark(self): + impls = { + "ninetoothed": ops.ninetoothed.torch.bmm, + "torch": torch.bmm, + "triton": ops.triton.torch.bmm, + } + + def make_inputs(b, m, n, k): + return ( + torch.randn((b, m, k), dtype=DTYPE, device=DEVICE), + torch.randn((b, k, n), dtype=DTYPE, device=DEVICE), + ), {} + + bench.benchmark( + impls, + make_inputs, + x_names=["m", "n", "k"], + x_vals=[2**i for i in range(3, 13)], + benchmark_args={"b": 4}, + tolerances={"triton": {"atol": 0, "rtol": 0}}, + name="bmm", + ) + + +class TestAddMMBenchmark: + def test_benchmark(self): + random.seed(0) + + impls = { + "ninetoothed": ops.ninetoothed.torch.addmm, + "torch": torch.addmm, + "triton": ops.triton.torch.addmm, + } + + def make_inputs(m, n, k): + return ( + torch.randn((m, n), dtype=DTYPE, device=DEVICE), + torch.randn((m, k), dtype=DTYPE, device=DEVICE), + torch.randn((k, n), dtype=DTYPE, device=DEVICE), + ), {"beta": random.uniform(0, 1), "alpha": random.uniform(0, 1)} + + bench.benchmark( + impls, + make_inputs, + x_names=["m", "n", "k"], + x_vals=[128 * i for i in range(2, 33)], + x_log=False, + assert_correctness=False, + name="addmm", + ) + + +class TestConv2DBenchmark: + def test_benchmark(self): + impls = { + "ninetoothed": ops.ninetoothed.torch.conv2d, + "torch": F.conv2d, + "triton": ops.triton.torch.conv2d, + } + + def make_inputs(n): + c, h, w = 512, 14, 14 + k, r, s = 512, 3, 3 + + return ( + torch.randn((n, c, h, w), dtype=DTYPE, device=DEVICE), + torch.randn((k, c, r, s), dtype=DTYPE, device=DEVICE), + ), {} + + bench.benchmark( + impls, + make_inputs, + x_names=["n"], + x_vals=[2**i for i in range(1, 11)], + tolerances={ + "torch": {"atol": 0.01, "rtol": 0.01}, + "triton": {"atol": 0, "rtol": 0}, + }, + name="conv2d", + ) + + +class TestSoftmaxBenchmark: + def test_benchmark(self): + impls = { + "ninetoothed": ops.ninetoothed.torch.softmax, + "torch": lambda input: torch.softmax(input, dim=-1), + "triton": ops.triton.torch.softmax, + } + + def make_inputs(m, n): + return (torch.randn(m, n, dtype=DTYPE, device=DEVICE),), {} + + bench.benchmark( + impls, + make_inputs, + x_names=["n"], + x_vals=[2**i for i in range(5, 15)], + benchmark_args={"m": 4096}, + tolerances={ + "torch": {"atol": 0.001}, + "triton": {"atol": 0, "rtol": 0}, + }, + name="softmax", + ) + + +class TestRMSNormBenchmark: + def test_benchmark(self): + impls = { + "ninetoothed": ops.ninetoothed.torch.rms_norm, + "torch": lambda input: F.rms_norm(input, input.shape[-1:]), + "triton": ops.triton.torch.rms_norm, + } + + def make_inputs(m, n): + return (torch.randn(m, n, dtype=DTYPE, device=DEVICE),), {} + + bench.benchmark( + impls, + make_inputs, + x_names=["n"], + x_vals=[2**i for i in range(5, 15)], + benchmark_args={"m": 4096}, + tolerances={ + "torch": {"atol": 0.001, "rtol": 0.005}, + "triton": {"atol": 0, "rtol": 0}, + }, + name="rms-norm", + ) + + +class TestFusedRMSNormBenchmark: + def test_benchmark(self): + impls = { + "ninetoothed": ops.ninetoothed.torch.fused_rms_norm, + "torch": lambda x, w, eps: F.rms_norm(x, x.shape[-1:], w, eps), + "triton": ops.triton.torch.fused_rms_norm, + } + + def make_inputs(m, n): + return ( + torch.randn(m, n, dtype=DTYPE, device=DEVICE), + torch.randn(n, dtype=DTYPE, device=DEVICE), + 1e-5, + ), {} + + bench.benchmark( + impls, + make_inputs, + x_names=["n"], + x_vals=[2**i for i in range(5, 15)], + benchmark_args={"m": 4096}, + tolerances={ + "torch": {"atol": 0.001, "rtol": 0.005}, + "triton": {"atol": 0.001, "rtol": 0.005}, + }, + name="fused-rms-norm", + ) + + +class TestSiLUBenchmark: + def test_benchmark(self): + impls = { + "ninetoothed": ops.ninetoothed.torch.silu, + "torch": F.silu, + "triton": ops.triton.torch.silu, + } + + def make_inputs(m, n, k): + return (torch.randn(m, n, k, dtype=DTYPE, device=DEVICE),), {} + + bench.benchmark( + impls, + make_inputs, + x_names=["m", "n", "k"], + x_vals=[2**i for i in range(3, 10)], + tolerances={ + "torch": {"atol": 0.001}, + "triton": {"atol": 0, "rtol": 0}, + }, + name="silu", + ) + + +class TestSwiGLUBenchmark: + def test_benchmark(self): + impls = { + "ninetoothed": ops.ninetoothed.torch.swiglu, + "torch": lambda a, b: a * F.silu(b), + "triton": ops.triton.torch.swiglu, + } + + def make_inputs(m, n): + return ( + torch.randn((m, n), dtype=DTYPE, device=DEVICE), + torch.randn((m, n), dtype=DTYPE, device=DEVICE), + ), {} + + bench.benchmark( + impls, + make_inputs, + x_names=["m", "n"], + x_vals=[128 * i for i in range(2, 50)], + assert_correctness=False, + name="swiglu", + ) + + +class TestRotaryPositionEmbeddingBenchmark: + def test_benchmark(self): + impls = { + "ninetoothed": ops.ninetoothed.torch.rotary_position_embedding, + "torch": torch_rotary_position_embedding, + "triton": ops.triton.torch.rotary_position_embedding, + } + + def make_inputs(seq_len): + batch_size, num_heads, emb_dim = 4, 32, 64 + sin_table, cos_table = generate_sin_and_cos_tables(seq_len, emb_dim) + + return ( + torch.randn( + (batch_size, seq_len, num_heads, emb_dim), + dtype=torch.float16, + device=DEVICE, + ), + sin_table, + cos_table, + ), {} + + bench.benchmark( + impls, + make_inputs, + x_names=["seq_len"], + x_vals=[2**i for i in range(5, 15)], + x_log=False, + assert_correctness=False, + name="rotary-position-embedding", + ) + + +class TestScaledDotProductAttentionBenchmark: + def test_benchmark(self): + impls = { + "ninetoothed": ops.ninetoothed.torch.scaled_dot_product_attention, + "torch": F.scaled_dot_product_attention, + "triton": ops.triton.torch.scaled_dot_product_attention, + } + + def make_inputs(seq_len): + batch_size, num_heads, emb_dim = 4, 32, 64 + shape = (batch_size, num_heads, seq_len, emb_dim) + + return ( + torch.randn(shape, dtype=DTYPE, device=DEVICE), + torch.randn(shape, dtype=DTYPE, device=DEVICE), + torch.randn(shape, dtype=DTYPE, device=DEVICE), + ), {} + + bench.benchmark( + impls, + make_inputs, + x_names=["seq_len"], + x_vals=[2**i for i in range(7, 17)], + tolerances={ + "torch": {"atol": 0.025, "rtol": 0.025}, + "triton": {"atol": 0.001, "rtol": 0.001}, + }, + name="scaled-dot-product-attention", + ) + + +class TestMaxPool2DBenchmark: + def test_benchmark(self): + impls = { + "ninetoothed": ops.ninetoothed.torch.max_pool2d, + "torch": lambda input, window_shape: F.max_pool2d( + input, window_shape, ceil_mode=True + ), + } + + def make_inputs(h, w): + return ( + torch.randn((64, 64, h, w), dtype=DTYPE, device=DEVICE), + (3, 3), + ), {} + + bench.benchmark( + impls, + make_inputs, + x_names=["h", "w"], + x_vals=[8 * i for i in range(2, 33)], + x_log=False, + assert_correctness=False, + name="max-pool2d", + ) diff --git a/tests/test_ops.py b/tests/test_ops.py new file mode 100644 index 0000000..a89940a --- /dev/null +++ b/tests/test_ops.py @@ -0,0 +1,259 @@ +import random + +import pytest +import torch +import torch.nn.functional as F + +import ops.ninetoothed.torch +import ops.triton.torch +from bench import assert_match +from modules import generate_sin_and_cos_tables, torch_rotary_position_embedding + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" +) + +DTYPE = torch.float16 +DEVICE = "cuda" + + +class TestAdd: + def test_correctness(self): + assert_match( + { + "ninetoothed": ops.ninetoothed.torch.add, + "torch": torch.add, + "triton": ops.triton.torch.add, + }, + args=( + torch.randn(98432, dtype=DTYPE, device=DEVICE), + torch.randn(98432, dtype=DTYPE, device=DEVICE), + ), + tolerances={"triton": {"atol": 0, "rtol": 0}}, + ) + + +class TestMM: + def test_correctness(self): + assert_match( + { + "ninetoothed": ops.ninetoothed.torch.mm, + "torch": torch.mm, + "triton": ops.triton.torch.mm, + }, + args=( + torch.randn((512, 512), dtype=DTYPE, device=DEVICE), + torch.randn((512, 512), dtype=DTYPE, device=DEVICE), + ), + tolerances={"triton": {"atol": 0, "rtol": 0}}, + ) + + +class TestBMM: + def test_correctness(self): + assert_match( + { + "ninetoothed": ops.ninetoothed.torch.bmm, + "torch": torch.bmm, + "triton": ops.triton.torch.bmm, + }, + args=( + torch.randn((4, 512, 1024), dtype=DTYPE, device=DEVICE), + torch.randn((4, 1024, 2028), dtype=DTYPE, device=DEVICE), + ), + tolerances={"triton": {"atol": 0, "rtol": 0}}, + ) + + +class TestAddMM: + def test_correctness(self): + random.seed(0) + shape = (512, 512) + beta = random.uniform(0, 1) + alpha = random.uniform(0, 1) + + assert_match( + { + "ninetoothed": ops.ninetoothed.torch.addmm, + "torch": torch.addmm, + "triton": ops.triton.torch.addmm, + }, + args=( + torch.randn(shape, dtype=DTYPE, device=DEVICE), + torch.randn(shape, dtype=DTYPE, device=DEVICE), + torch.randn(shape, dtype=DTYPE, device=DEVICE), + ), + kwargs={"beta": beta, "alpha": alpha}, + tolerances={"torch": {"atol": 0.01, "rtol": 0.01}}, + ) + + +class TestConv2D: + def test_correctness(self): + n, c, h, w = 4, 3, 224, 224 + k, r, s = 8, 3, 3 + + assert_match( + { + "ninetoothed": ops.ninetoothed.torch.conv2d, + "torch": F.conv2d, + "triton": ops.triton.torch.conv2d, + }, + args=( + torch.randn(n, c, h, w, dtype=DTYPE, device=DEVICE), + torch.randn(k, c, r, s, dtype=DTYPE, device=DEVICE), + ), + tolerances={ + "torch": {"atol": 0.01, "rtol": 0.01}, + "triton": {"atol": 0, "rtol": 0}, + }, + ) + + +class TestSoftmax: + def test_correctness(self): + assert_match( + { + "ninetoothed": ops.ninetoothed.torch.softmax, + "torch": lambda input: torch.softmax(input, dim=-1), + "triton": ops.triton.torch.softmax, + }, + args=(torch.randn(1823, 781, dtype=DTYPE, device=DEVICE),), + tolerances={ + "torch": {"atol": 0.001}, + "triton": {"atol": 0, "rtol": 0}, + }, + ) + + +class TestRMSNorm: + def test_correctness(self): + assert_match( + { + "ninetoothed": ops.ninetoothed.torch.rms_norm, + "torch": lambda input: F.rms_norm(input, input.shape[-1:]), + "triton": ops.triton.torch.rms_norm, + }, + args=(torch.randn(1151, 8192, dtype=DTYPE, device=DEVICE),), + tolerances={ + "torch": {"atol": 0.001, "rtol": 0.005}, + "triton": {"atol": 0, "rtol": 0}, + }, + ) + + +class TestFusedRMSNorm: + def test_correctness(self): + assert_match( + { + "ninetoothed": ops.ninetoothed.torch.fused_rms_norm, + "torch": lambda x, w, eps: F.rms_norm(x, x.shape[-1:], w, eps), + "triton": ops.triton.torch.fused_rms_norm, + }, + args=( + torch.randn(1151, 8192, dtype=DTYPE, device=DEVICE), + torch.randn(8192, dtype=DTYPE, device=DEVICE), + 1e-5, + ), + tolerances={ + "torch": {"atol": 0.001, "rtol": 0.005}, + "triton": {"atol": 0.001, "rtol": 0.005}, + }, + ) + + +class TestSiLU: + def test_correctness(self): + assert_match( + { + "ninetoothed": ops.ninetoothed.torch.silu, + "torch": F.silu, + "triton": ops.triton.torch.silu, + }, + args=(torch.randn((8, 256, 512), dtype=DTYPE, device=DEVICE),), + tolerances={ + "torch": {"atol": 1e-3, "rtol": 1e-3}, + "triton": {"atol": 0, "rtol": 0}, + }, + ) + + +class TestSwiGLU: + def test_correctness(self): + assert_match( + { + "ninetoothed": ops.ninetoothed.torch.swiglu, + "torch": lambda a, b: a * F.silu(b), + "triton": ops.triton.torch.swiglu, + }, + args=( + torch.randn((13, 3), dtype=DTYPE, device=DEVICE), + torch.randn((13, 3), dtype=DTYPE, device=DEVICE), + ), + tolerances={"torch": {"atol": 0, "rtol": 1e-3}}, + ) + + +class TestRotaryPositionEmbedding: + def test_correctness(self): + batch_size, seq_len, num_heads, emb_dim = 4, 128, 8, 64 + sin_table, cos_table = generate_sin_and_cos_tables(seq_len, emb_dim) + + assert_match( + { + "ninetoothed": ops.ninetoothed.torch.rotary_position_embedding, + "torch": torch_rotary_position_embedding, + "triton": ops.triton.torch.rotary_position_embedding, + }, + args=( + torch.randn( + batch_size, + seq_len, + num_heads, + emb_dim, + dtype=torch.float32, + device=DEVICE, + ), + sin_table, + cos_table, + ), + kwargs={"interleaved": False}, + tolerances={ + "torch": {"atol": 0.001}, + "triton": {"atol": 0, "rtol": 0}, + }, + ) + + +class TestScaledDotProductAttention: + def test_correctness(self): + q_o_shape = (2, 8, 1024, 64) + k_v_shape = (2, 8, 1024, 64) + + assert_match( + { + "ninetoothed": ops.ninetoothed.torch.scaled_dot_product_attention, + "torch": F.scaled_dot_product_attention, + "triton": ops.triton.torch.scaled_dot_product_attention, + }, + args=( + torch.randn(q_o_shape, dtype=DTYPE, device=DEVICE), + torch.randn(k_v_shape, dtype=DTYPE, device=DEVICE), + torch.randn(k_v_shape, dtype=DTYPE, device=DEVICE), + ), + tolerances={ + "torch": {"atol": 0.01}, + "triton": {"atol": 1e-3, "rtol": 0}, + }, + ) + + +class TestMaxPool2D: + def test_correctness(self): + input = torch.randn((32, 3, 64, 64), dtype=DTYPE, device=DEVICE) + window_shape = (3, 3) + + ninetoothed_output = ops.ninetoothed.torch.max_pool2d(input, window_shape) + torch_output = F.max_pool2d(input, window_shape, ceil_mode=True) + + assert torch.allclose(ninetoothed_output, torch_output) diff --git a/utils.py b/utils.py deleted file mode 100644 index 44641c8..0000000 --- a/utils.py +++ /dev/null @@ -1,17 +0,0 @@ -def replace_module(module, replacement_class): - for child_name, child_module in module.named_children(): - if replacement_class.__name__ not in child_module.__class__.__name__: - replace_module(child_module, replacement_class) - continue - - replacement = replacement_class(child_module) - setattr(module, child_name, replacement) - - -def find_module_types(module): - types = {type(module)} - - for child_module in module.children(): - types.update(find_module_types(child_module)) - - return types