Skip to content

Commit 8d1bf96

Browse files
authored
comment clear_jax memory function (#83)
1 parent e3cf453 commit 8d1bf96

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

Ironwood/src/benchmark_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,9 @@ def multiple_iteration_timeit_from_trace(
169169

170170
result = compute_func(*data_args)
171171
jax.block_until_ready(result)
172-
clear_jax_memory()
172+
173+
# Commenting it out as it's causing issues with GEMM
174+
# clear_jax_memory()
173175
trace = get_trace(tmp_trace_dir)
174176

175177
if trace_full_dir != tmp_trace_dir:

0 commit comments

Comments
 (0)