@@ -12,6 +12,8 @@
 
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
 
+from benchmark_model_configs import MODEL_REGISTRY
+from benchmark_model_configs import compute_model_config_sweep_config
 from benchmark_model_configs import compute_seq_len_sweep_config
 from benchmark_model_configs import estimate_kernel_peak_memory
 from benchmark_model_configs import get_benchmark_model_config
@@ -69,61 +71,154 @@ def bench_memory_attn_res(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput
     return run_memory_benchmark(fn, input.kernel_operation_mode)
 
 
-if __name__ == "__main__":
-    args = parse_benchmark_script_args()
-
-    model = get_benchmark_model_config(args.model)
-    probe_seq_len = 1024
-
-    def _probe():
-        probe_input = SingleBenchmarkRunInput(
-            x=probe_seq_len,
-            kernel_provider="pytorch",
+def _resolve_model_config_attn_res(input: SingleBenchmarkRunInput):
+    """Resolve model-config-sweep input into standard setup args."""
+    cfg = input.extra_benchmark_config
+    model_info = cfg["model_configs"][input.x]
+    return _setup_attn_res(
+        SingleBenchmarkRunInput(
+            x=cfg["seq_len"],
+            kernel_provider=input.kernel_provider,
             extra_benchmark_config={
-                "N": 8,
-                "bsz": 1,
-                "hidden_size": model.hidden_size,
-                "dtype": model.dtype,
-                "eps": 1e-6,
+                "N": cfg["N"],
+                "bsz": cfg["bsz"],
+                "hidden_size": model_info["hidden_size"],
+                "dtype": model_info["dtype"],
+                "eps": cfg.get("eps", 1e-6),
             },
         )
-        V, fn = _setup_attn_res(probe_input)
-        return fn()
-
-    peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
-    kernel_bpt = peak_bytes // probe_seq_len
-
-    config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
-
-    common_configs = {
-        "kernel_name": "attn_res",
-        "x_name": "T",
-        "x_label": "sequence length",
-        "x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)],
-        "kernel_providers": ["liger", "pytorch"],
-        "extra_benchmark_configs": [
-            {
-                "N": 8,
-                "bsz": config.batch_size,
-                "hidden_size": model.hidden_size,
-                "dtype": model.dtype,
-                "eps": 1e-6,
-            }
-        ],
-        "overwrite": args.overwrite,
-    }
-
-    run_benchmarks(
-        bench_test_fn=bench_speed_attn_res,
-        kernel_operation_modes=["full", "forward", "backward"],
-        metric_name="speed",
-        metric_unit="ms",
-        **common_configs,
-    )
-    run_benchmarks(
-        bench_test_fn=bench_memory_attn_res,
-        kernel_operation_modes=["full", "forward", "backward"],
-        metric_name="memory",
-        metric_unit="MB",
-        **common_configs,
     )
+
+
+def bench_speed_attn_res_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    V, fn = _resolve_model_config_attn_res(input)
+    return run_speed_benchmark(fn, input.kernel_operation_mode, [V])
+
+
+def bench_memory_attn_res_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    V, fn = _resolve_model_config_attn_res(input)
+    return run_memory_benchmark(fn, input.kernel_operation_mode)
+
+
+if __name__ == "__main__":
+    args = parse_benchmark_script_args()
+
+    if args.sweep_mode == "model_config":
+        all_model_configs = list(MODEL_REGISTRY.values())
+
+        def _probe_factory(model_cfg, probe_seq_len):
+            def _probe():
+                probe_input = SingleBenchmarkRunInput(
+                    x=probe_seq_len,
+                    kernel_provider="pytorch",
+                    extra_benchmark_config={
+                        "N": 8,
+                        "bsz": 1,
+                        "hidden_size": model_cfg.hidden_size,
+                        "dtype": model_cfg.dtype,
+                        "eps": 1e-6,
+                    },
+                )
+                V, fn = _setup_attn_res(probe_input)
+                return fn()
+
+            return _probe
+
+        sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt)
+
+        model_configs_info = {
+            cfg.name: {
+                "hidden_size": cfg.hidden_size,
+                "dtype": cfg.dtype,
+            }
+            for cfg in sweep.model_configs
+        }
+
+        common_configs = {
+            "kernel_name": "attn_res",
+            "x_name": "model_config",
+            "x_label": "model configuration",
+            "x_values": [cfg.name for cfg in sweep.model_configs],
+            "kernel_providers": ["liger", "pytorch"],
+            "extra_benchmark_configs": [
+                {
+                    "model_configs": model_configs_info,
+                    "N": 8,
+                    "bsz": sweep.batch_size,
+                    "seq_len": sweep.seq_len,
+                    "eps": 1e-6,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_attn_res_model_config,
+            kernel_operation_modes=["full", "forward", "backward"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_attn_res_model_config,
+            kernel_operation_modes=["full", "forward", "backward"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
+    else:
+        model = get_benchmark_model_config(args.model)
+        probe_seq_len = 1024
+
+        def _probe():
+            probe_input = SingleBenchmarkRunInput(
+                x=probe_seq_len,
+                kernel_provider="pytorch",
+                extra_benchmark_config={
+                    "N": 8,
+                    "bsz": 1,
+                    "hidden_size": model.hidden_size,
+                    "dtype": model.dtype,
+                    "eps": 1e-6,
+                },
+            )
+            V, fn = _setup_attn_res(probe_input)
+            return fn()
+
+        peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
+        kernel_bpt = peak_bytes // probe_seq_len
+
+        config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
+
+        common_configs = {
+            "kernel_name": "attn_res",
+            "x_name": "T",
+            "x_label": "sequence length",
+            "x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)],
+            "kernel_providers": ["liger", "pytorch"],
+            "extra_benchmark_configs": [
+                {
+                    "N": 8,
+                    "bsz": config.batch_size,
+                    "hidden_size": model.hidden_size,
+                    "dtype": model.dtype,
+                    "eps": 1e-6,
+                }
+            ],
+            "overwrite": args.overwrite,
+        }
+
+        run_benchmarks(
+            bench_test_fn=bench_speed_attn_res,
+            kernel_operation_modes=["full", "forward", "backward"],
+            metric_name="speed",
+            metric_unit="ms",
+            **common_configs,
+        )
+        run_benchmarks(
+            bench_test_fn=bench_memory_attn_res,
+            kernel_operation_modes=["full", "forward", "backward"],
+            metric_name="memory",
+            metric_unit="MB",
+            **common_configs,
+        )
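
Reviewer note: the new model_config branch leans entirely on compute_model_config_sweep_config, which this diff does not show. For readers following along, here is a minimal sketch of what that helper plausibly does, inferred from the call site (probe_fn_factory=_probe_factory, bt=args.bt) and from how sweep.model_configs, sweep.batch_size, and sweep.seq_len are consumed above. The SweepConfig dataclass, the token-budget reading of bt, and the halving loop are all assumptions, not the actual implementation in benchmark_model_configs.

# Hypothetical sketch only -- the real helper lives in benchmark_model_configs
# and may differ. SweepConfig and the meaning of `bt` are assumptions.
from dataclasses import dataclass

import torch

from benchmark_model_configs import estimate_kernel_peak_memory

PROBE_SEQ_LEN = 1024  # same probe length the script uses for the seq-len sweep


@dataclass
class SweepConfig:
    model_configs: list  # configs measured under the chosen budget
    batch_size: int
    seq_len: int


def compute_model_config_sweep_config(model_configs, probe_fn_factory, bt):
    """Pick one (batch_size, seq_len) budget that every model config fits under.

    Assumes `bt` is an upper bound on batch_size * seq_len tokens.
    """
    # Probe each model at a short sequence and track the worst-case
    # bytes-per-token, mirroring the existing seq-len sweep path.
    worst_bpt = 0
    for cfg in model_configs:
        probe = probe_fn_factory(cfg, PROBE_SEQ_LEN)
        peak_bytes = estimate_kernel_peak_memory(probe_fn=probe)
        worst_bpt = max(worst_bpt, peak_bytes // PROBE_SEQ_LEN)

    # Halve the sequence length until the worst-case model fits in free memory.
    free_bytes, _ = torch.cuda.mem_get_info()
    batch_size, seq_len = 1, bt
    while seq_len > PROBE_SEQ_LEN and worst_bpt * batch_size * seq_len > free_bytes:
        seq_len //= 2

    return SweepConfig(model_configs=list(model_configs), batch_size=batch_size, seq_len=seq_len)

Whatever the real implementation, the contract visible in the diff is the important part: every model is measured at the same (bsz, seq_len) point, so the x-axis compares architectures rather than sequence lengths.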