1616
1717import tempfile
1818
19+ import datetime
1920import numpy as np
2021import json
2122import unittest
2425import random
2526import os
2627import os .path
28+ import warnings
2729
2830from maxtext .common .gcloud_stub import is_decoupled
2931from maxtext .trainers .pre_train .train import main as train_main
@@ -52,6 +54,7 @@ def setUp(self):
5254
5355 @pytest .mark .integration_test
5456 @pytest .mark .tpu_only
57+ @pytest .mark .filterwarnings ("always::UserWarning" )
5558 def test_grad_accumulate_same_loss (self ):
5659 random_suffix = generate_random_string ()
5760 temp_dir = tempfile .gettempdir ()
@@ -71,6 +74,7 @@ def test_grad_accumulate_same_loss(self):
7174 "steps=20" ,
7275 ]
7376 # Run with gradient accumulation with accumulate_steps=10, per_device_batch=1 --> simulating per_device_batch=10
77+ start_time = datetime .datetime .now ()
7478 train_main (
7579 shared_maxtext_args
7680 + [
@@ -80,8 +84,12 @@ def test_grad_accumulate_same_loss(self):
8084 "gradient_accumulation_steps=10" ,
8185 ]
8286 )
87+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
88+ print (f"train with GA costs { seconds_elapsed } secs" )
89+ warnings .warn (f"DEBUG: train with GA costs { seconds_elapsed } secs" )
8390
8491 # Run without gradient accumulation with per_device_batch=10
92+ start_time = datetime .datetime .now ()
8593 train_main (
8694 shared_maxtext_args
8795 + [
@@ -91,8 +99,12 @@ def test_grad_accumulate_same_loss(self):
9199 "gradient_accumulation_steps=1" ,
92100 ]
93101 )
102+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
103+ print (f"train with regular costs { seconds_elapsed } secs" )
104+ warnings .warn (f"DEBUG: train with regular costs { seconds_elapsed } secs" )
94105
95106 # Assert losses roughly equal
107+ start_time = datetime .datetime .now ()
96108 with (
97109 open (run_accumulate_metrics_file , "rt" , encoding = "utf8" ) as accum_run ,
98110 open (run_regular_metrics_file , "rt" , encoding = "utf8" ) as regular_run ,
@@ -109,8 +121,12 @@ def test_grad_accumulate_same_loss(self):
109121 )
110122 # Not identical due to an epsilon addition in loss denominator.
111123 np .testing .assert_allclose (accum_run_loss , regular_run_loss , rtol = 0.01 )
124+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
125+ print (f"comparing losses costs { seconds_elapsed } secs" )
126+ warnings .warn (f"DEBUG: comparing losses costs { seconds_elapsed } secs" )
112127
113128 # Assert grad norms roughly equal
129+ start_time = datetime .datetime .now ()
114130 with (
115131 open (run_accumulate_metrics_file , "rt" , encoding = "utf8" ) as accum_run ,
116132 open (run_regular_metrics_file , "rt" , encoding = "utf8" ) as regular_run ,
@@ -127,8 +143,12 @@ def test_grad_accumulate_same_loss(self):
127143 )
128144 # Not identical due to an epsilon addition in loss denominator.
129145 np .testing .assert_allclose (accum_run_grad_norm , regular_run_grad_norm , rtol = 0.01 )
146+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
147+ print (f"comparing grad norms costs { seconds_elapsed } secs" )
148+ warnings .warn (f"DEBUG: comparing grad norms costs { seconds_elapsed } secs" )
130149
131150 # Assert per device tflops are the same (10x smaller microbatch size, but 10x more microbatches)
151+ start_time = datetime .datetime .now ()
132152 with (
133153 open (run_accumulate_metrics_file , "rt" , encoding = "utf8" ) as accum_run ,
134154 open (run_regular_metrics_file , "rt" , encoding = "utf8" ) as regular_run ,
@@ -144,6 +164,9 @@ def test_grad_accumulate_same_loss(self):
144164 flush = True ,
145165 )
146166 np .testing .assert_equal (accum_device_tflops , regular_device_tflops )
167+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
168+ print (f"comparing tflops costs { seconds_elapsed } secs" )
169+ warnings .warn (f"DEBUG: comparing tflops costs { seconds_elapsed } secs" )
147170
148171 @pytest .mark .integration_test
149172 @pytest .mark .tpu_only
0 commit comments