6969 is_flag = True ,
7070 help = "Flag for printing results in CSV format" ,
7171)
72+ @click .option (
73+ "--compile-ref" ,
74+ is_flag = True ,
75+ help = "Flag to torch.compile() the reference impl" ,
76+ )
77+ @click .option (
78+ "--compile-conch" ,
79+ is_flag = True ,
80+ help = "Flag to torch.compile() the Conch impl" ,
81+ )
7282def main (
7383 embedding_size : int ,
7484 num_tokens : int ,
@@ -78,6 +88,8 @@ def main(
7888 verbose : bool ,
7989 gpu : str ,
8090 csv : bool ,
91+ compile_ref : bool ,
92+ compile_conch : bool ,
8193) -> None :
8294 """Benchmark Conch GemmaRMSNorm op.
8395
@@ -90,6 +102,8 @@ def main(
90102 verbose: Flag to indicate whether or not to print verbose output.
91103 gpu: Which gpu to run on.
92104 csv: Flag for printing results in CSV format.
105+ compile_ref: Flag to torch.compile() the reference implementation.
106+ compile_conch: Flag to torch.compile() the Conch implementation.
93107 """
94108 seed : Final = 0
95109 seed_everything (seed )
@@ -113,8 +127,11 @@ def main(
113127 x_ref = x .clone ()
114128 x_conch = x .clone ()
115129
116- result_ref = gemma_rms_norm_reference (x_ref , weights , epsilon , residual = None )
117- result_conch = gemma_rms_norm_conch (x_conch , weights , epsilon , residual = None )
130+ gemma_rms_norm_ref_fn = torch .compile (gemma_rms_norm_reference ) if compile_ref else gemma_rms_norm_reference
131+ gemma_rms_norm_conch_fn = torch .compile (gemma_rms_norm_conch ) if compile_conch else gemma_rms_norm_conch
132+
133+ result_ref = gemma_rms_norm_ref_fn (x_ref , weights , epsilon , residual = None )
134+ result_conch = gemma_rms_norm_conch_fn (x_conch , weights , epsilon , residual = None )
118135
119136 # For mypy (if residual==None then result is single Tensor, not tuple[Tensor, Tensor])
120137 assert isinstance (result_ref , torch .Tensor )
@@ -131,15 +148,15 @@ def main(
131148 print (f"Results matched with atol={ absolute_tolerance } :)" , file = sys .stderr )
132149
133150 baseline_result = benchmark_it (
134- lambda : gemma_rms_norm_reference (x_ref , weights , epsilon , residual = None ),
151+ lambda : gemma_rms_norm_ref_fn (x_ref , weights , epsilon , residual = None ),
135152 tag = "Baseline" ,
136153 metadata = metadata ,
137154 iteration_time_ms = iteration_time_ms ,
138155 warmup_time_ms = warmup_time_ms ,
139156 )
140157
141158 conch_result = benchmark_it (
142- lambda : gemma_rms_norm_conch (x_conch , weights , epsilon , residual = None ),
159+ lambda : gemma_rms_norm_conch_fn (x_conch , weights , epsilon , residual = None ),
143160 tag = "Conch" ,
144161 metadata = metadata ,
145162 iteration_time_ms = iteration_time_ms ,
0 commit comments