|
| 1 | +import torch |
| 2 | +import triton |
| 3 | + |
# Mapping from internal provider keys to human-readable names used in
# assertion messages and plot legends.  Keys not listed here fall back to
# the key itself (see ``_display_name``).
DISPLAY_NAMES = {
    "ninetoothed": "NineToothed",
    "torch": "PyTorch",
    "triton": "Triton",
}

# ``(color, linestyle)`` pairs cycled through when assigning plot styles to
# benchmark lines (see ``_style``).
STYLES = [
    ("blue", "-"),
    ("green", "-"),
    ("orange", "-"),
    ("red", "-"),
    ("purple", "-"),
    ("cyan", "-"),
]
| 18 | + |
| 19 | + |
def assert_match(impls, args, kwargs=None, tolerances=None):
    """Assert that all implementations produce matching outputs.

    Same API as ``check``, but raises ``AssertionError`` on mismatch
    instead of printing. Intended for use in test suites.

    The first entry of ``impls`` is treated as the reference
    implementation; every other entry is compared against it with
    ``torch.allclose``.

    :param impls: Ordered dict mapping provider name to callable.
    :param args: Tuple of positional arguments.
    :param kwargs: Dict of keyword arguments.
    :param tolerances: Dict mapping provider name to ``torch.allclose`` kwargs.
    :raises AssertionError: If any implementation's output differs from the
        reference beyond the given tolerances.
    """
    kwargs = kwargs or {}
    tolerances = tolerances or {}
    results = {name: fn(*args, **kwargs) for name, fn in impls.items()}

    names = list(impls)
    reference_name = names[0]
    reference = results[reference_name]

    for name in names[1:]:
        tol = tolerances.get(name, {})
        # The message expression of ``assert`` is evaluated only on failure,
        # so the display-name lookups cost nothing when the outputs match.
        assert torch.allclose(reference, results[name], **tol), (
            f"{_display_name(reference_name)} and {_display_name(name)}"
            " outputs differ."
        )
| 47 | + |
| 48 | + |
def benchmark(
    impls,
    make_inputs,
    x_names,
    x_vals,
    name,
    benchmark_args=None,
    x_log=True,
    assert_correctness=True,
    tolerances=None,
    save_path=".",
):
    """Create and run a performance benchmark.

    Each provider in ``impls`` is timed with ``triton.testing.do_bench`` at
    every point of the parameter sweep; results are printed and, when
    ``save_path`` is given, plotted.

    :param impls: Ordered dict mapping provider name to callable.
    :param make_inputs: Callable returning ``(args_tuple, kwargs_dict)``.
    :param x_names: List of benchmark parameter names.
    :param x_vals: List of benchmark parameter values.
    :param name: Operator name, used for the plot filename.
    :param benchmark_args: Fixed benchmark args dict.
    :param x_log: Whether to use log scale for the x-axis.
    :param assert_correctness: Whether to assert correctness at each point.
    :param tolerances: Dict mapping provider name to ``torch.allclose`` kwargs.
    :param save_path: Directory to save plot files, or ``None`` to skip saving.
    """
    providers = list(impls)

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=x_names,
            x_vals=x_vals,
            line_arg="provider",
            line_vals=providers,
            line_names=[_display_name(p) for p in providers],
            plot_name=f"{name}-performance",
            args=benchmark_args or {},
            ylabel="ms",
            x_log=x_log,
            styles=[_style(i) for i in range(len(providers))],
        )
    )
    def bench(provider, **params):
        args, kwargs = make_inputs(**params)

        if assert_correctness:
            # Reuse ``assert_match`` instead of duplicating the comparison
            # loop, so failures carry the same readable messages as the
            # standalone checker.
            assert_match(impls, args, kwargs, tolerances)

        return triton.testing.do_bench(lambda: impls[provider](*args, **kwargs))

    bench.run(print_data=True, save_path=save_path)
| 105 | + |
| 106 | + |
def _display_name(name):
    """Map a provider key to its human-readable name.

    Falls back to the key itself when no mapping is registered.
    """

    try:
        return DISPLAY_NAMES[name]
    except KeyError:
        return name
| 111 | + |
| 112 | + |
def _style(index):
    """Return a ``(color, linestyle)`` pair, wrapping around ``STYLES``."""

    wrapped = index % len(STYLES)
    return STYLES[wrapped]
0 commit comments