@@ -84,69 +84,73 @@ def run_benchmarks(device_str):
8484 x_nan = torch .randn (10000 , 100 , device = device )
8585 mask = torch .rand_like (x_nan ) < 0.1
8686 x_nan [mask ] = float ('nan' )
87- benchmarks ['replace_nan_and_inf' ] = (math_utils .replace_nan_and_inf , (x_nan .clone (), 0 ))
87+ benchmarks ['replace_nan_and_inf' ] = (math_utils .replace_nan_and_inf , (x_nan .clone (), 0 ), True )
8888
8989 # angular_diff_batch
9090 a_ang = torch .randn (100000 , device = device )
9191 b_ang = torch .randn (100000 , device = device )
92- benchmarks ['angular_diff_batch' ] = (math_utils .angular_diff_batch , (a_ang , b_ang ))
92+ benchmarks ['angular_diff_batch' ] = (math_utils .angular_diff_batch , (a_ang , b_ang ), False )
9393
9494 # angle_between_stable
9595 u_abs = torch .randn (200 , 50 , device = device )
9696 v_abs = torch .randn (150 , 50 , device = device )
97- benchmarks ['angle_between_stable' ] = (math_utils .angle_between_stable , (u_abs , v_abs ))
97+ benchmarks ['angle_between_stable' ] = (math_utils .angle_between_stable , (u_abs , v_abs ), False )
9898
9999 # cos_sim_pairwise
100100 x1_cos = torch .randn (500 , 50 , device = device )
101101 x2_cos = torch .randn (300 , 50 , device = device )
102- benchmarks ['cos_sim_pairwise' ] = (math_utils .cos_sim_pairwise , (x1_cos , x2_cos ))
102+ benchmarks ['cos_sim_pairwise' ] = (math_utils .cos_sim_pairwise , (x1_cos , x2_cos ), False )
103103
104104 # batch_batch_product
105105 X_bbp = torch .randn (10000 , 20 , device = device )
106106 A_bbp = torch .randn (10000 , 20 , 20 , device = device )
107- benchmarks ['batch_batch_product' ] = (linalg .batch_batch_product , (X_bbp , A_bbp ))
107+ benchmarks ['batch_batch_product' ] = (linalg .batch_batch_product , (X_bbp , A_bbp ), False )
108108
109109 # batch_quadratic_product
110110 X_bqp = torch .randn (10000 , 20 , device = device )
111111 A_bqp = make_psd (20 , device )
112- benchmarks ['batch_quadratic_product' ] = (linalg .batch_quadratic_product , (X_bqp , A_bqp ))
112+ benchmarks ['batch_quadratic_product' ] = (linalg .batch_quadratic_product , (X_bqp , A_bqp ), False )
113113
114114 # batch_outer_product
115115 u_bop = torch .randn (10000 , 20 , device = device )
116116 v_bop = torch .randn (10000 , 20 , device = device )
117- benchmarks ['batch_outer_product' ] = (linalg .batch_outer_product , (u_bop , v_bop ))
117+ benchmarks ['batch_outer_product' ] = (linalg .batch_outer_product , (u_bop , v_bop ), False )
118118
119119 # squeeze_n
120120 x_sq = torch .randn (1 , 1 , 1 , 1000 , 50 , device = device )
121- benchmarks ['squeeze_n' ] = (lambda x : tensor_utils .squeeze_n ( x , 3 ), ( x_sq ,) )
121+ benchmarks ['squeeze_n' ] = (tensor_utils .squeeze_n , ( x_sq , 3 ), False )
122122
123123 # MinMaxScaler.transform
124124 x_mm = torch .randn (10000 , 50 , device = device )
125125 scaler = preprocess .MinMaxScaler ()
126126 scaler .fit (x_mm )
127- benchmarks ['MinMaxScaler.transform' ] = (scaler .transform , (x_mm ,))
127+ benchmarks ['MinMaxScaler.transform' ] = (scaler .transform , (x_mm ,), False )
128128
129129 # SoftKNN.forward
130130 x_knn = torch .randn (200 , 10 , device = device )
131131 knn = softknn .SoftKNN (min_k = 20 )
132- benchmarks ['SoftKNN.forward' ] = (knn , (x_knn ,))
132+ benchmarks ['SoftKNN.forward' ] = (knn , (x_knn ,), False )
133133
134134 # sqrtm (CPU only due to .numpy())
135135 if device_str == 'cpu' :
136136 A_sqrtm = make_psd (50 , device )
137- benchmarks ['sqrtm' ] = (linalg .sqrtm , (A_sqrtm ,))
137+ benchmarks ['sqrtm' ] = (linalg .sqrtm , (A_sqrtm ,), False )
138138
139139 # --- Run benchmarks ---
140140 print (f"\n { 'Function' :<30} { 'Eager (ms)' :>12} { 'Compile (ms)' :>14} { 'Speedup' :>10} { 'Compile OK' :>12} " )
141141 print ("-" * 80 )
142142
143- for name , (fn , args ) in benchmarks .items ():
143+ for name , (fn , args , needs_clone ) in benchmarks .items ():
144144 # Eager benchmark
145- # For replace_nan_and_inf, need fresh clone each call
146- if name == 'replace_nan_and_inf' :
147- def eager_fn (x_template = x_nan ):
148- return math_utils .replace_nan_and_inf (x_template .clone (), 0 )
149- eager_ms = bench (eager_fn , warmup = 5 , repeats = 20 , device = device_str )
145+ if needs_clone :
146+ # For in-place functions, clone first arg each call
147+ template = args [0 ]
148+ rest_args = args [1 :]
149+
150+ def cloning_fn (* a , _fn = fn , _tpl = template , _rest = rest_args ):
151+ return _fn (_tpl .clone (), * _rest )
152+
153+ eager_ms = bench (cloning_fn , warmup = 5 , repeats = 20 , device = device_str )
150154 else :
151155 try :
152156 eager_ms = bench (fn , * args , device = device_str )
@@ -157,10 +161,7 @@ def eager_fn(x_template=x_nan):
157161 continue
158162
159163 # Compile benchmark
160- if name == 'replace_nan_and_inf' :
161- compile_result = try_compile_bench (eager_fn , device = device_str )
162- else :
163- compile_result = try_compile_bench (fn , * args , device = device_str )
164+ compile_result = try_compile_bench (fn , * args , device = device_str )
164165
165166 if len (compile_result ) == 2 :
166167 compile_ms , compile_ok = compile_result
0 commit comments