1818training run (using train.py).
1919"""
2020
21+ import datetime
2122import tempfile
2223import unittest
2324import pytest
2425import os
2526import shutil
2627import hashlib
28+ import warnings
2729import re
2830import jax
2931from tests .utils .test_helpers import get_test_config_path
@@ -194,7 +196,11 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
194196 get_test_config_path (),
195197 f"dump_jaxpr_local_dir={ train_dump_dir } " ,
196198 ) + tuple (shared_args )
199+ start_time = datetime .datetime .now ()
197200 train .main (train_argv )
201+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
202+ print (f"train cost { seconds_elapsed } secs" )
203+ warnings .warn (f"DEUBG: train cost { seconds_elapsed } secs" )
198204 jax .clear_caches ()
199205
200206 # Run train_compile.py and dump jaxpr
@@ -206,7 +212,11 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
206212 f"compile_topology={ topology } " ,
207213 "compile_topology_num_slices=1" ,
208214 ) + tuple (shared_args )
215+ start_time = datetime .datetime .now ()
209216 train_compile .main (compile_argv )
217+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
218+ print (f"train_compile cost { seconds_elapsed } secs" )
219+ warnings .warn (f"DEUBG: train_compile cost { seconds_elapsed } secs" )
210220 jax .clear_caches ()
211221
212222 # Compare results
@@ -218,5 +228,6 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
218228 )
219229
220230 @pytest .mark .tpu_only
231+ @pytest .mark .filterwarnings ("always::UserWarning" )
221232 def test_default_jaxpr_match (self ):
222233 self .assert_compile_and_real_match_jaxpr ("default_run" )
0 commit comments