Skip to content

Commit 77ecd7c

Browse files
author
Charles Li
committed
Add time stamp for GA test
1 parent 5478bad commit 77ecd7c

1 file changed

Lines changed: 23 additions & 0 deletions

File tree

tests/integration/gradient_accumulation_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import tempfile
1818

19+
import datetime
1920
import numpy as np
2021
import json
2122
import unittest
@@ -24,6 +25,7 @@
2425
import random
2526
import os
2627
import os.path
28+
import warnings
2729

2830
from maxtext.common.gcloud_stub import is_decoupled
2931
from maxtext.trainers.pre_train.train import main as train_main
@@ -52,6 +54,7 @@ def setUp(self):
5254

5355
@pytest.mark.integration_test
5456
@pytest.mark.tpu_only
57+
@pytest.mark.filterwarnings("always::UserWarning")
5558
def test_grad_accumulate_same_loss(self):
5659
random_suffix = generate_random_string()
5760
temp_dir = tempfile.gettempdir()
@@ -71,6 +74,7 @@ def test_grad_accumulate_same_loss(self):
7174
"steps=20",
7275
]
7376
# Run with gradient accumulation with accumulate_steps=10, per_device_batch=1 --> simulating per_device_batch=10
77+
start_time = datetime.datetime.now()
7478
train_main(
7579
shared_maxtext_args
7680
+ [
@@ -80,8 +84,12 @@ def test_grad_accumulate_same_loss(self):
8084
"gradient_accumulation_steps=10",
8185
]
8286
)
87+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
88+
print(f"train with GA costs {seconds_elapsed} secs")
89+
warnings.warn(f"DEUBG: train with GA costs {seconds_elapsed} secs")
8390

8491
# Run without gradient accumulation with per_device_batch=10
92+
start_time = datetime.datetime.now()
8593
train_main(
8694
shared_maxtext_args
8795
+ [
@@ -91,8 +99,12 @@ def test_grad_accumulate_same_loss(self):
9199
"gradient_accumulation_steps=1",
92100
]
93101
)
102+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
103+
print(f"train with regular costs {seconds_elapsed} secs")
104+
warnings.warn(f"DEUBG: train with regular costs {seconds_elapsed} secs")
94105

95106
# Assert losses roughly equal
107+
start_time = datetime.datetime.now()
96108
with (
97109
open(run_accumulate_metrics_file, "rt", encoding="utf8") as accum_run,
98110
open(run_regular_metrics_file, "rt", encoding="utf8") as regular_run,
@@ -109,8 +121,12 @@ def test_grad_accumulate_same_loss(self):
109121
)
110122
# Not identical due to an epsilon addition in loss denominator.
111123
np.testing.assert_allclose(accum_run_loss, regular_run_loss, rtol=0.01)
124+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
125+
print(f"comparing losses costs {seconds_elapsed} secs")
126+
warnings.warn(f"DEUBG: comparing losses costs {seconds_elapsed} secs")
112127

113128
# Assert grad norms roughly equal
129+
start_time = datetime.datetime.now()
114130
with (
115131
open(run_accumulate_metrics_file, "rt", encoding="utf8") as accum_run,
116132
open(run_regular_metrics_file, "rt", encoding="utf8") as regular_run,
@@ -127,8 +143,12 @@ def test_grad_accumulate_same_loss(self):
127143
)
128144
# Not identical due to an epsilon addition in loss denominator.
129145
np.testing.assert_allclose(accum_run_grad_norm, regular_run_grad_norm, rtol=0.01)
146+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
147+
print(f"comparing grad norms {seconds_elapsed} secs")
148+
warnings.warn(f"DEUBG: comparing grad norms costs {seconds_elapsed} secs")
130149

131150
# Assert per device tflops are the same (10x smaller microbatch size, but 10x more microbatches)
151+
start_time = datetime.datetime.now()
132152
with (
133153
open(run_accumulate_metrics_file, "rt", encoding="utf8") as accum_run,
134154
open(run_regular_metrics_file, "rt", encoding="utf8") as regular_run,
@@ -144,6 +164,9 @@ def test_grad_accumulate_same_loss(self):
144164
flush=True,
145165
)
146166
np.testing.assert_equal(accum_device_tflops, regular_device_tflops)
167+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
168+
print(f"comparing tflops {seconds_elapsed} secs")
169+
warnings.warn(f"DEUBG: comparing tflops costs {seconds_elapsed} secs")
147170

148171
@pytest.mark.integration_test
149172
@pytest.mark.tpu_only

0 commit comments

Comments
 (0)