@@ -124,7 +124,7 @@ def add_calculate_metrics(
124124 total_bytes , SHARDING_STRATEGY
125125 )
126126 return unified_bytes_metrics (
127- m , n , time_ms_list , total_bytes , total_bytes_all_devices
127+ m , n , time_ms_list , total_bytes , total_bytes_all_devices , dtype = dtype . dtype . name
128128 )
129129
130130
@@ -191,7 +191,7 @@ def rmsnorm_calculate_metrics(
191191 total_bytes , SHARDING_STRATEGY
192192 )
193193 return unified_bytes_metrics (
194- m , n , time_ms_list , total_bytes , total_bytes_all_devices
194+ m , n , time_ms_list , total_bytes , total_bytes_all_devices , dtype = dtype . dtype . name
195195 )
196196
197197
@@ -264,7 +264,7 @@ def silu_mul_calculate_metrics(
264264 total_bytes , SHARDING_STRATEGY
265265 )
266266 return unified_bytes_metrics (
267- m , n , time_ms_list , total_bytes , total_bytes_all_devices
267+ m , n , time_ms_list , total_bytes , total_bytes_all_devices , dtype = dtype . dtype . name
268268 )
269269
270270
@@ -325,7 +325,7 @@ def sigmoid_calculate_metrics(
325325 total_bytes , SHARDING_STRATEGY
326326 )
327327 return unified_bytes_metrics (
328- m , n , time_ms_list , total_bytes , total_bytes_all_devices
328+ m , n , time_ms_list , total_bytes , total_bytes_all_devices , dtype = dtype . dtype . name
329329 )
330330
331331
@@ -413,4 +413,4 @@ def sigmoid_calculate_metrics(
413413# scale = 2 if dtype == jnp.bfloat16 else 1
414414# total_bytes = scale * 3 * m
415415# total_bytes, total_bytes_all_devices = handle_based_on_sharding(total_bytes, SHARDING_STRATEGY)
416- # return unified_bytes_metrics(m, 0, time_ms_list, total_bytes, total_bytes_all_devices)
416+ # return unified_bytes_metrics(m, 0, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name )
0 commit comments