Skip to content

Commit c8e277d

Browse files
Merge pull request #3800 from AI-Hypercomputer:add_resharding_integration_test
PiperOrigin-RevId: 910405797
2 parents abfdb42 + 7f7c2a9 commit c8e277d

1 file changed

Lines changed: 142 additions & 0 deletions

File tree

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Integration tests for checkpoint resharding functionality.
16+
17+
These tests verify that a training run saves a checkpoint using one mesh topology,
18+
and that a subsequent run can successfully restore and continue training using a
19+
different mesh topology (resharding).
20+
"""
21+
22+
from datetime import datetime
23+
import json
24+
from math import isclose
25+
import pytest
26+
27+
from maxtext.trainers.pre_train.train import main as train_main
28+
from tests.utils.test_helpers import (
29+
get_test_config_path,
30+
get_test_base_output_directory,
31+
get_test_dataset_path,
32+
)
33+
34+
35+
def get_resharding_command(run_date, steps, metrics_file, base_output_directory, dataset_path, parallelism_args):
    """Builds the argument list handed to the trainer for one resharding test run.

    The returned list is laid out as: a leading ``None`` entry, the test config
    path, the per-run settings, the fixed model-size overrides and finally the
    caller-supplied ``parallelism_args``.

    Args:
      run_date: String used to form the run name (``runner_<run_date>``).
      steps: Value for the ``steps`` setting.
      metrics_file: Value for the ``metrics_file`` setting.
      base_output_directory: Value for the ``base_output_directory`` setting.
      dataset_path: Value for the ``dataset_path`` setting.
      parallelism_args: List of extra ``key=value`` strings appended last.

    Returns:
      A new list combining all of the above.
    """
    # The first slot is None; presumably it stands in for the program name
    # (argv[0]) expected by the trainer entry point -- confirm against train.main.
    run_settings = [
        None,
        get_test_config_path(),
        f"run_name=runner_{run_date}",
        f"steps={steps}",
        f"metrics_file={metrics_file}",
        f"base_output_directory={base_output_directory}",
        f"dataset_path={dataset_path}",
        "dataset_type=synthetic",
        "grain_worker_count=0",
        "collect_stack_trace=False",
    ]

    # Fixed model dimensions shared by every run of this test.
    model_overrides = [
        "base_emb_dim=384",
        "base_num_query_heads=8",
        "base_num_kv_heads=8",
        "base_mlp_dim=192",
        "base_num_decoder_layers=8",
        "head_dim=128",
    ]

    return run_settings + model_overrides + parallelism_args
62+
63+
64+
def _read_metric(metrics_path, target, index):
    """Returns `target` from the record at `index` among the non-blank lines of `metrics_path`."""
    with open(metrics_path, "rt", encoding="utf8") as metrics:
        # Ignore blank lines so a trailing newline at the end of the file does
        # not get mistaken for the last record.
        records = [line for line in metrics if line.strip()]
    if not records:
        raise ValueError(f"Metrics file {metrics_path!r} contains no records.")
    return json.loads(records[index])[target]


def check_loss(metrics_file_suffix, target, rel_tol=0.1):
    """Asserts that a metric matches across a checkpoint save and restore.

    Compares the value of `target` in the last record of
    ``"saved_" + metrics_file_suffix`` with its value in the first record of
    ``"restored_" + metrics_file_suffix``. Each file is read as one JSON object
    per line; blank lines are skipped.

    Args:
      metrics_file_suffix: Shared file-name suffix of the two metrics files.
      target: Key of the metric to compare in each JSON record.
      rel_tol: Relative tolerance passed to `math.isclose`. Defaults to 0.1.

    Raises:
      ValueError: If either metrics file contains no records.
      AssertionError: If the two values are not within `rel_tol` of each other.
    """
    # Final logged value of the run that wrote the checkpoint.
    saved_loss = _read_metric("saved_" + metrics_file_suffix, target, -1)
    # First logged value of the run that restored the checkpoint.
    restored_loss = _read_metric("restored_" + metrics_file_suffix, target, 0)

    print("Saved loss: ", saved_loss)
    print("Restored loss: ", restored_loss)
    # A successful restore should continue from roughly where the saved run
    # stopped, so the two values are expected to be close.
    assert isclose(saved_loss, restored_loss, rel_tol=rel_tol), (
        f"{target} differs across restore: saved={saved_loss}, "
        f"restored={restored_loss}, rel_tol={rel_tol}"
    )
88+
89+
90+
@pytest.mark.integration_test
@pytest.mark.tpu_only
@pytest.mark.scheduled_only
def test_checkpoint_resharding():
    """Saves a checkpoint under one mesh layout, restores it under another, and compares losses."""
    timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    output_dir = get_test_base_output_directory()
    data_dir = get_test_dataset_path()

    def run_training(steps, metrics_file, parallelism_args):
        # Both phases share the run name, output directory and dataset; only
        # the step count, metrics file and parallelism settings differ.
        train_main(
            get_resharding_command(
                timestamp,
                steps=steps,
                metrics_file=metrics_file,
                base_output_directory=output_dir,
                dataset_path=data_dir,
                parallelism_args=parallelism_args,
            )
        )

    # Phase 1 -- train and save a checkpoint with FSDP=4, Tensor=1.
    # steps=1 executes step 0; save_checkpoint_on_completion=True saves
    # checkpoint 0 (model state after step 0) when the job completes.
    run_training(
        1,
        "saved_metrics.txt",
        [
            "checkpoint_period=10",
            "save_checkpoint_on_completion=True",
            "dcn_data_parallelism=1",
            "dcn_fsdp_parallelism=1",
            "ici_fsdp_parallelism=4",
            "ici_tensor_parallelism=1",
        ],
    )

    # Phase 2 -- restore and continue with FSDP=2, Tensor=2.
    # 'steps' is the target global step: the run restores checkpoint 0, sets
    # start_step=1 and executes step 1 to reach global step 2.
    run_training(
        2,
        "restored_metrics.txt",
        [
            "dcn_data_parallelism=1",
            "dcn_fsdp_parallelism=1",
            "ici_fsdp_parallelism=2",
            "ici_tensor_parallelism=2",
        ],
    )

    # Phase 3 -- the loss logged right after restoring should be close to the
    # loss logged right before saving.
    check_loss("metrics.txt", "learning/loss")

0 commit comments

Comments
 (0)