Skip to content

Commit 06bc894

Browse files
Add torch.compile() option
1 parent 82a8161 commit 06bc894

11 files changed

Lines changed: 223 additions & 56 deletions

benchmarks/copy_blocks_benchmark.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,16 @@
9898
is_flag=True,
9999
help="Flag for printing results in CSV format",
100100
)
101+
@click.option(
102+
"--compile-ref",
103+
is_flag=True,
104+
help="Flag to torch.compile() the reference impl",
105+
)
106+
@click.option(
107+
"--compile-conch",
108+
is_flag=True,
109+
help="Flag to torch.compile() the Conch impl",
110+
)
101111
def main(
102112
head_dim: int,
103113
num_layers: int,
@@ -111,6 +121,8 @@ def main(
111121
verbose: bool,
112122
gpu: str,
113123
csv: bool,
124+
compile_ref: bool,
125+
compile_conch: bool,
114126
) -> None:
115127
"""Benchmark Conch copy_blocks operation.
116128
@@ -127,6 +139,8 @@ def main(
127139
verbose: Flag to indicate whether or not to print verbose output.
128140
gpu: Which gpu to run on.
129141
csv: Flag to indicate whether or not to print results in CSV format.
142+
compile_ref: Flag to torch.compile() the reference implementation.
143+
compile_conch: Flag to torch.compile() the Conch implementation.
130144
"""
131145
seed: Final = 0
132146
seed_everything(seed)
@@ -179,10 +193,13 @@ def main(
179193
# Convert mapping list to tensor
180194
block_mapping_tensor = torch.tensor(block_mapping, dtype=torch.int64, device=device).view(-1, 2)
181195

196+
copy_blocks_ref_fn = torch.compile(copy_blocks_reference) if compile_ref else copy_blocks_reference
197+
copy_blocks_conch_fn = torch.compile(copy_blocks_conch) if compile_conch else copy_blocks_conch
198+
182199
# Run the reference implementation.
183-
copy_blocks_reference(cloned_key_caches, cloned_value_caches, block_mapping)
200+
copy_blocks_ref_fn(cloned_key_caches, cloned_value_caches, block_mapping)
184201
# Call Conch kernel
185-
copy_blocks_conch(key_caches, value_caches, block_mapping_tensor)
202+
copy_blocks_conch_fn(key_caches, value_caches, block_mapping_tensor)
186203

187204
# Compare the results.
188205
num_key_matched = 0
@@ -215,7 +232,7 @@ def main(
215232

216233
# Benchmark Reference vs. Conch implementations
217234
baseline_result = benchmark_it(
218-
lambda: copy_blocks_reference(
235+
lambda: copy_blocks_ref_fn(
219236
cloned_key_caches,
220237
cloned_value_caches,
221238
block_mapping,
@@ -227,7 +244,7 @@ def main(
227244
)
228245

229246
conch_result = benchmark_it(
230-
lambda: copy_blocks_conch(
247+
lambda: copy_blocks_conch_fn(
231248
key_caches,
232249
value_caches,
233250
block_mapping_tensor,

benchmarks/fused_add_rms_norm_benchmark.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,16 @@
6262
is_flag=True,
6363
help="Flag for printing results in CSV format",
6464
)
65+
@click.option(
66+
"--compile-ref",
67+
is_flag=True,
68+
help="Flag to torch.compile() the reference impl",
69+
)
70+
@click.option(
71+
"--compile-conch",
72+
is_flag=True,
73+
help="Flag to torch.compile() the Conch impl",
74+
)
6575
def main( # noqa: PLR0913
6676
hidden_size: int,
6777
num_tokens: int,
@@ -70,6 +80,8 @@ def main( # noqa: PLR0913
7080
verbose: bool,
7181
gpu: str,
7282
csv: bool,
83+
compile_ref: bool,
84+
compile_conch: bool,
7385
) -> None:
7486
"""Benchmark Conch fused_add_rms_norm op.
7587
@@ -81,6 +93,8 @@ def main( # noqa: PLR0913
8193
verbose: Flag to indicate whether or not to print verbose output.
8294
gpu: Which gpu to run on.
8395
csv: Flag to indicate whether or not to print results in CSV format.
96+
compile_ref: Flag to torch.compile() the reference implementation.
97+
compile_conch: Flag to torch.compile() the Conch implementation.
8498
"""
8599
seed: Final = 0
86100
seed_everything(seed)
@@ -110,9 +124,14 @@ def main( # noqa: PLR0913
110124
conch_residual = residual.clone()
111125
ref_residual = residual.clone()
112126

113-
conch_output, conch_residual = fused_add_rms_norm_conch(conch_x, conch_residual, weight, epsilon)
127+
fused_add_rms_norm_conch_fn = torch.compile(fused_add_rms_norm_conch) if compile_conch else fused_add_rms_norm_conch
128+
fused_add_rms_norm_ref_fn = (
129+
torch.compile(fused_add_rms_norm_reference) if compile_ref else fused_add_rms_norm_reference
130+
)
131+
132+
conch_output, conch_residual = fused_add_rms_norm_conch_fn(conch_x, conch_residual, weight, epsilon)
114133

115-
ref_output, ref_residual = fused_add_rms_norm_reference(ref_x, ref_residual, weight, epsilon)
134+
ref_output, ref_residual = fused_add_rms_norm_ref_fn(ref_x, ref_residual, weight, epsilon)
116135

117136
if not torch.allclose(ref_output, conch_output, atol=tolerance, rtol=tolerance):
118137
print(f"WARNING: Reference and Conch results differ! (atol={tolerance}, rtol={tolerance})", file=sys.stderr)
@@ -136,7 +155,7 @@ def main( # noqa: PLR0913
136155

137156
# Benchmark Reference vs. Conch implementations
138157
baseline_result = benchmark_it(
139-
lambda: fused_add_rms_norm_reference(
158+
lambda: fused_add_rms_norm_ref_fn(
140159
ref_x,
141160
ref_residual,
142161
weight,
@@ -149,7 +168,7 @@ def main( # noqa: PLR0913
149168
)
150169

151170
conch_result = benchmark_it(
152-
lambda: fused_add_rms_norm_conch(
171+
lambda: fused_add_rms_norm_conch_fn(
153172
conch_x,
154173
conch_residual,
155174
weight,

benchmarks/gelu_tanh_and_mul_benchmark.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,16 @@
6969
is_flag=True,
7070
help="Flag for printing results in CSV format",
7171
)
72+
@click.option(
73+
"--compile-ref",
74+
is_flag=True,
75+
help="Flag to torch.compile() the reference impl",
76+
)
77+
@click.option(
78+
"--compile-conch",
79+
is_flag=True,
80+
help="Flag to torch.compile() the Conch impl",
81+
)
7282
def main(
7383
hidden_size: int,
7484
num_tokens: int,
@@ -78,6 +88,8 @@ def main(
7888
verbose: bool,
7989
gpu: str,
8090
csv: bool,
91+
compile_ref: bool,
92+
compile_conch: bool,
8193
) -> None:
8294
"""Benchmark Conch GeluTanhAndMul op.
8395
@@ -90,6 +102,8 @@ def main(
90102
verbose: Flag to indicate whether or not to print verbose output.
91103
gpu: Which gpu to run on.
92104
csv: Flag for printing results in CSV format.
105+
compile_ref: Flag to torch.compile() the reference implementation.
106+
compile_conch: Flag to torch.compile() the Conch implementation.
93107
"""
94108
seed: Final = 0
95109
seed_everything(seed)
@@ -107,8 +121,13 @@ def main(
107121

108122
projections = torch.rand((num_tokens, hidden_size * 2), device=device)
109123

110-
ref_output = gelu_tanh_and_mul_reference(projections)
111-
conch_output = gelu_tanh_and_mul_conch(projections)
124+
gelu_tanh_and_mul_ref_fn = (
125+
torch.compile(gelu_tanh_and_mul_reference) if compile_ref else gelu_tanh_and_mul_reference
126+
)
127+
gelu_tanh_and_mul_conch_fn = torch.compile(gelu_tanh_and_mul_conch) if compile_conch else gelu_tanh_and_mul_conch
128+
129+
ref_output = gelu_tanh_and_mul_ref_fn(projections)
130+
conch_output = gelu_tanh_and_mul_conch_fn(projections)
112131

113132
if not torch.allclose(ref_output, conch_output, atol=absolute_tolerance):
114133
print(f"WARNING: Reference and Conch results differ! (atol={absolute_tolerance})", file=sys.stderr)
@@ -121,15 +140,15 @@ def main(
121140
print(f"Results matched with atol={absolute_tolerance} :)", file=sys.stderr)
122141

123142
baseline_result = benchmark_it(
124-
lambda: gelu_tanh_and_mul_reference(projections),
143+
lambda: gelu_tanh_and_mul_ref_fn(projections),
125144
tag="Baseline",
126145
metadata=metadata,
127146
iteration_time_ms=iteration_time_ms,
128147
warmup_time_ms=warmup_time_ms,
129148
)
130149

131150
conch_result = benchmark_it(
132-
lambda: gelu_tanh_and_mul_conch(projections),
151+
lambda: gelu_tanh_and_mul_conch_fn(projections),
133152
tag="Conch",
134153
metadata=metadata,
135154
iteration_time_ms=iteration_time_ms,

benchmarks/gemma_rms_norm_benchmark.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,16 @@
6969
is_flag=True,
7070
help="Flag for printing results in CSV format",
7171
)
72+
@click.option(
73+
"--compile-ref",
74+
is_flag=True,
75+
help="Flag to torch.compile() the reference impl",
76+
)
77+
@click.option(
78+
"--compile-conch",
79+
is_flag=True,
80+
help="Flag to torch.compile() the Conch impl",
81+
)
7282
def main(
7383
embedding_size: int,
7484
num_tokens: int,
@@ -78,6 +88,8 @@ def main(
7888
verbose: bool,
7989
gpu: str,
8090
csv: bool,
91+
compile_ref: bool,
92+
compile_conch: bool,
8193
) -> None:
8294
"""Benchmark Conch GemmaRMSNorm op.
8395
@@ -90,6 +102,8 @@ def main(
90102
verbose: Flag to indicate whether or not to print verbose output.
91103
gpu: Which gpu to run on.
92104
csv: Flag for printing results in CSV format.
105+
compile_ref: Flag to torch.compile() the reference implementation.
106+
compile_conch: Flag to torch.compile() the Conch implementation.
93107
"""
94108
seed: Final = 0
95109
seed_everything(seed)
@@ -113,8 +127,11 @@ def main(
113127
x_ref = x.clone()
114128
x_conch = x.clone()
115129

116-
result_ref = gemma_rms_norm_reference(x_ref, weights, epsilon, residual=None)
117-
result_conch = gemma_rms_norm_conch(x_conch, weights, epsilon, residual=None)
130+
gemma_rms_norm_ref_fn = torch.compile(gemma_rms_norm_reference) if compile_ref else gemma_rms_norm_reference
131+
gemma_rms_norm_conch_fn = torch.compile(gemma_rms_norm_conch) if compile_conch else gemma_rms_norm_conch
132+
133+
result_ref = gemma_rms_norm_ref_fn(x_ref, weights, epsilon, residual=None)
134+
result_conch = gemma_rms_norm_conch_fn(x_conch, weights, epsilon, residual=None)
118135

119136
# For mypy (if residual==None then result is single Tensor, not tuple[Tensor, Tensor])
120137
assert isinstance(result_ref, torch.Tensor)
@@ -131,15 +148,15 @@ def main(
131148
print(f"Results matched with atol={absolute_tolerance} :)", file=sys.stderr)
132149

133150
baseline_result = benchmark_it(
134-
lambda: gemma_rms_norm_reference(x_ref, weights, epsilon, residual=None),
151+
lambda: gemma_rms_norm_ref_fn(x_ref, weights, epsilon, residual=None),
135152
tag="Baseline",
136153
metadata=metadata,
137154
iteration_time_ms=iteration_time_ms,
138155
warmup_time_ms=warmup_time_ms,
139156
)
140157

141158
conch_result = benchmark_it(
142-
lambda: gemma_rms_norm_conch(x_conch, weights, epsilon, residual=None),
159+
lambda: gemma_rms_norm_conch_fn(x_conch, weights, epsilon, residual=None),
143160
tag="Conch",
144161
metadata=metadata,
145162
iteration_time_ms=iteration_time_ms,

benchmarks/reshape_and_cache_benchmark.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,16 @@
9898
is_flag=True,
9999
help="Flag for printing results in CSV format",
100100
)
101+
@click.option(
102+
"--compile-ref",
103+
is_flag=True,
104+
help="Flag to torch.compile() the reference impl",
105+
)
106+
@click.option(
107+
"--compile-conch",
108+
is_flag=True,
109+
help="Flag to torch.compile() the Conch impl",
110+
)
101111
def main(
102112
head_dim: int,
103113
num_tokens: int,
@@ -111,6 +121,8 @@ def main(
111121
verbose: bool,
112122
gpu: str,
113123
csv: bool,
124+
compile_ref: bool,
125+
compile_conch: bool,
114126
) -> None:
115127
"""Benchmark Conch reshape_and_cache.
116128
@@ -127,6 +139,8 @@ def main(
127139
verbose: Flag to indicate whether or not to print verbose output.
128140
gpu: Which gpu to run on.
129141
csv: Flag for printing results in CSV format.
142+
compile_ref: Flag to torch.compile() the reference implementation.
143+
compile_conch: Flag to torch.compile() the Conch implementation.
130144
"""
131145
if kv_cache_dtype != "auto" and not current_platform.supports_fp8():
132146
error_msg = "Cannot use FP8 KV Cache because current platform does not support FP8"
@@ -184,13 +198,16 @@ def main(
184198
key_cache_conch = key_cache_ref.clone()
185199
value_cache_conch = value_cache_ref.clone()
186200

187-
# Run the reference implementation.
188-
reshape_and_cache_reference(
189-
key, value, key_cache_ref, value_cache_ref, slot_mapping, kv_cache_dtype, k_scale, v_scale
201+
reshape_and_cache_ref_fn = (
202+
torch.compile(reshape_and_cache_reference) if compile_ref else reshape_and_cache_reference
190203
)
204+
reshape_and_cache_conch_fn = torch.compile(reshape_and_cache_conch) if compile_conch else reshape_and_cache_conch
205+
206+
# Run the reference implementation.
207+
reshape_and_cache_ref_fn(key, value, key_cache_ref, value_cache_ref, slot_mapping, kv_cache_dtype, k_scale, v_scale)
191208

192209
# Call Conch kernel
193-
reshape_and_cache_conch(
210+
reshape_and_cache_conch_fn(
194211
key, value, key_cache_conch, value_cache_conch, slot_mapping, kv_cache_dtype, k_scale, v_scale
195212
)
196213

@@ -230,7 +247,7 @@ def main(
230247

231248
# Benchmark Reference vs. Conch implementations
232249
baseline_result = benchmark_it(
233-
lambda: reshape_and_cache_reference(
250+
lambda: reshape_and_cache_ref_fn(
234251
key,
235252
value,
236253
key_cache_ref,
@@ -247,7 +264,7 @@ def main(
247264
)
248265

249266
conch_result = benchmark_it(
250-
lambda: reshape_and_cache_conch(
267+
lambda: reshape_and_cache_conch_fn(
251268
key,
252269
value,
253270
key_cache_conch,

0 commit comments

Comments
 (0)