-
Notifications
You must be signed in to change notification settings - Fork 527
Expand file tree
/
Copy pathcheckpoint_compatibility_test.py
More file actions
98 lines (82 loc) · 3.34 KB
/
checkpoint_compatibility_test.py
File metadata and controls
98 lines (82 loc) · 3.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Integration tests to check compatibility of checkpoints between different input pipelines.
These tests verify that a checkpoint saved during a training run using one
input pipeline (e.g., 'grain') can be successfully restored and continued
by a subsequent training run using a different input pipeline (e.g., 'tfds').
The tests confirm restoration by checking the starting step of the resumed runs.
Note: Make sure to run
`bash src/dependencies/scripts/setup_gcsfuse.sh DATASET_GCS_BUCKET=gs://maxtext-dataset MOUNT_PATH=/tmp/gcsfuse/`
before running tests locally.
"""
from datetime import datetime
import json
import os
import pytest
from maxtext.trainers.pre_train.train import main as train_main
from maxtext.utils.globals import MAXTEXT_REPO_ROOT
from tests.integration.checkpointing_test import get_checkpointing_command
def check_start_step(metrics_file, start_step_target):
  """Assert that the first metrics record in `metrics_file` is at the expected step.

  The file is treated as one JSON object per line. Only the first line is parsed,
  and its "step" field is compared (as a float) against `start_step_target`.

  Args:
    metrics_file: Path to the metrics file written by a training run.
    start_step_target: Expected step of the first recorded metric. Converted with
      `float()` before comparison, so ints, floats and numeric strings all work.

  Raises:
    AssertionError: If the file is empty or the first recorded step does not
      match `start_step_target`.
  """
  with open(metrics_file, "rt", encoding="utf8") as metrics:
    # Only the first record matters; don't read the rest of the file.
    first_line = metrics.readline()
  # An empty file would otherwise surface as an opaque parse/index error.
  assert first_line, f"Metrics file {metrics_file} is empty; no step was recorded."
  start_step = json.loads(first_line)["step"]
  print(f"Start step is {start_step}, start step target is {start_step_target}")
  assert start_step == float(start_step_target)
def run_checkpoint_compatibility(hardware, attention_type):
  """Save a checkpoint with the grain pipeline, then resume it with tfds.

  Runs one training step with ``dataset_type="grain"``. It then launches a second
  run (``steps=2``) with ``dataset_type="tfds"``, using the same ``run_date`` and
  output directory. Finally it checks that the second run's first recorded metric
  is at step 1, i.e. training resumed from the grain checkpoint rather than
  starting over.

  Args:
    hardware: Hardware string forwarded to ``get_checkpointing_command``
      (this module passes "tpu" or "gpu").
    attention_type: Attention implementation forwarded to
      ``get_checkpointing_command`` (this module passes "autoselected" or
      "dot_product").
  """
  run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
  local_ckpt_dir = "/tmp/maxtext_local_output"
  run_1_metrics = os.path.join(MAXTEXT_REPO_ROOT, "run_1_metrics.txt")
  run_2_metrics = os.path.join(MAXTEXT_REPO_ROOT, "run_2_metrics.txt")

  # The metrics files live at fixed paths, so remove any left over from an
  # earlier invocation; the start-step check below must only see this run.
  # NOTE(review): this matters if the trainer appends to an existing metrics
  # file on resume instead of truncating it — confirm against the trainer.
  for stale_metrics in (run_1_metrics, run_2_metrics):
    try:
      os.remove(stale_metrics)
    except FileNotFoundError:
      pass

  # Arguments shared by both runs; they differ only in step count, metrics file
  # and input pipeline.
  common_kwargs = {
      "hardware": hardware,
      "attention_type": attention_type,
      "dataset_path": "/tmp/gcsfuse",
      "base_output_directory": local_ckpt_dir,
  }
  grain_command = [
      "grain_worker_count=0",
      "grain_train_files=/tmp/gcsfuse/array-record/c4/en/3.0.1/c4-train.array_record*",
  ]

  # Run training using grain input pipeline
  train_main(
      get_checkpointing_command(
          run_date,
          steps=1,
          metrics_file=run_1_metrics,
          dataset_type="grain",
          **common_kwargs,
      )
      + grain_command
  )
  # Resume training using tfds input pipeline
  train_main(
      get_checkpointing_command(
          run_date,
          steps=2,
          metrics_file=run_2_metrics,
          dataset_type="tfds",
          **common_kwargs,
      )
  )
  check_start_step(run_2_metrics, 1.0)
@pytest.mark.integration_test
@pytest.mark.tpu_only
@pytest.mark.skip(reason="Flaky test b/470704234")
def test_autoselected_attention():
  """On TPU, a grain-saved checkpoint resumes under tfds with autoselected attention."""
  run_checkpoint_compatibility(hardware="tpu", attention_type="autoselected")
@pytest.mark.external_training
@pytest.mark.integration_test
@pytest.mark.gpu_only
def test_with_dot_product():
  """On GPU, a grain-saved checkpoint resumes under tfds with dot-product attention.

  Sets ``NVTE_FUSED_ATTN=1`` (enable fused attention) only for the duration of
  this test. The previous value, or its absence, is restored afterwards so the
  setting does not leak into other tests running in the same process.
  """
  previous = os.environ.get("NVTE_FUSED_ATTN")
  os.environ["NVTE_FUSED_ATTN"] = "1"  # Enable fused attention
  try:
    run_checkpoint_compatibility("gpu", "dot_product")
  finally:
    if previous is None:
      os.environ.pop("NVTE_FUSED_ATTN", None)
    else:
      os.environ["NVTE_FUSED_ATTN"] = previous