diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index 19a6642d..d0e72b9e 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -169,7 +169,9 @@ def multiple_iteration_timeit_from_trace( result = compute_func(*data_args) jax.block_until_ready(result) - clear_jax_memory() + + # Commenting it out as it's causing issues with GEMM + # clear_jax_memory() trace = get_trace(tmp_trace_dir) if trace_full_dir != tmp_trace_dir: