Skip to content

Commit aabbf07

Browse files
author
Charles Li
committed
Add time info through user warning for debugging b/496201097
1 parent 77f5334 commit aabbf07

1 file changed

Lines changed: 11 additions & 0 deletions

File tree

tests/integration/aot_identical_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
training run (using train.py).
1919
"""
2020

21+
import datetime
2122
import tempfile
2223
import unittest
2324
import pytest
2425
import os
2526
import shutil
2627
import hashlib
28+
import warnings
2729
import re
2830
import jax
2931
from tests.utils.test_helpers import get_test_config_path
@@ -194,7 +196,11 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
194196
get_test_config_path(),
195197
f"dump_jaxpr_local_dir={train_dump_dir}",
196198
) + tuple(shared_args)
199+
start_time = datetime.datetime.now()
197200
train.main(train_argv)
201+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
202+
print(f"train cost {seconds_elapsed} secs")
203+
warnings.warn(f"DEUBG: train cost {seconds_elapsed} secs")
198204
jax.clear_caches()
199205

200206
# Run train_compile.py and dump jaxpr
@@ -206,7 +212,11 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
206212
f"compile_topology={topology}",
207213
"compile_topology_num_slices=1",
208214
) + tuple(shared_args)
215+
start_time = datetime.datetime.now()
209216
train_compile.main(compile_argv)
217+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
218+
print(f"train_compile cost {seconds_elapsed} secs")
219+
warnings.warn(f"DEUBG: train_compile cost {seconds_elapsed} secs")
210220
jax.clear_caches()
211221

212222
# Compare results
@@ -218,5 +228,6 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
218228
)
219229

220230
@pytest.mark.tpu_only
231+
@pytest.mark.filterwarnings("always::UserWarning")
221232
def test_default_jaxpr_match(self):
222233
self.assert_compile_and_real_match_jaxpr("default_run")

0 commit comments

Comments
 (0)