forked from pytorch/torchtitan
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathflux.py
More file actions
executable file
·155 lines (130 loc) · 5.29 KB
/
Copy pathflux.py
File metadata and controls
executable file
·155 lines (130 loc) · 5.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os
import shlex

from torchtitan.tools.logging import logger

from tests.integration_tests import OverrideDefinitions
from tests.integration_tests.run_tests import _run_cmd
def build_flux_test_list() -> list[OverrideDefinitions]:
    """Return the list of Flux integration-test flavors.

    Each OverrideDefinitions bundles:
      * a list of override-arg lists, one shell command per inner list, run in order;
      * a human-readable description;
      * a ``test_name`` that can be selected with ``--test_name``;
      * optionally ``ngpu``, the number of GPUs the flavor requires.
        The FSDP+compile flavor relies on OverrideDefinitions' default.
    """
    # Every flavor starts from the Flux debug model config.
    base_args = ["--module flux", "--config flux_debugmodel"]
    integration_tests_flavors = [
        OverrideDefinitions(
            [
                # Step 1 (idx == 0): train with HSDP + CP, run validation,
                # and save a checkpoint.
                # 2 (shard) x 2 (replicate) x 2 (CP) = 8 GPUs, matching ngpu=8.
                [
                    *base_args,
                    "--parallelism.data_parallel_shard_degree 2",
                    "--parallelism.data_parallel_replicate_degree 2",
                    "--parallelism.context_parallel_degree 2",
                    "--validator.enable",
                    "--validator.steps 5",
                    "--checkpoint.enable",
                ],
                # Step 2 (idx == 1): no extra overrides. For this test_name,
                # run_single_test launches the inference script for this step.
                [],
            ],
            "HSDP+CP+Validation+Inference",
            "hsdp+cp+validation+inference",
            ngpu=8,
        ),
        OverrideDefinitions(
            [
                [
                    *base_args,
                    "--compile.enable",
                ],
            ],
            "Flux FSDP+compile",
            "flux_fsdp+compile",
        ),
    ]
    return integration_tests_flavors


# Maps a test-suite name to the function that builds its flavor list.
_TEST_SUITES_FUNCTION = {
    "flux": build_flux_test_list,
}
def run_single_test(test_flavor: OverrideDefinitions, output_dir: str):
    """Run every command of one Flux integration-test flavor, in order.

    Each entry of ``test_flavor.override_args`` produces one shell command.
    Normally that command launches ``./run_train.sh`` with a TORCH_TRACE
    compile-trace dump.

    For the ``hsdp+cp+validation+inference`` flavor, the second entry
    (``idx == 1``) instead launches ``torchtitan/models/flux/run_infer.sh``.
    It uses the same dump folder, where the first (training) step saved its
    checkpoint.

    Args:
        test_flavor: Flavor to run. Its ``ngpu``, ``test_name``,
            ``test_descr`` and ``override_args`` are read.
        output_dir: Root output directory. Each flavor dumps into
            ``<output_dir>/<test_name>``.

    Raises:
        RuntimeError: If any command exits with a non-zero return code.
    """
    test_name = test_flavor.test_name
    test_dir = f"{output_dir}/{test_name}"

    # The command is executed through a shell. Quote paths so an output_dir
    # containing spaces or shell metacharacters stays a single argument.
    # shlex.quote returns plain paths unchanged.
    dump_folder_arg = f"--dump_folder {shlex.quote(test_dir)}"
    trace_env = f"TORCH_TRACE={shlex.quote(test_dir + '/compile_trace')}"

    # Flags shared by every command. Encoders are randomly initialized so the
    # test works offline; tokenizer/encoder configs come from tests/assets.
    common_args = " ".join(
        [
            dump_folder_arg,
            "--tokenizer.test_mode --encoder.random_init",
            "--encoder.clip_encoder tests/assets/flux_test_encoders/clip-vit-large-patch14/",
            "--encoder.t5_encoder tests/assets/flux_test_encoders/t5-v1_1-xxl/",
            "--tokenizer.t5_tokenizer_path tests/assets/tokenizer",
            "--tokenizer.clip_tokenizer_path tests/assets/tokenizer",
            "--hf_assets_path tests/assets/tokenizer",
        ]
    )

    all_ranks = ",".join(map(str, range(test_flavor.ngpu)))
    launch_env = f"NGPU={test_flavor.ngpu} LOG_RANK={all_ranks}"

    for idx, override_arg in enumerate(test_flavor.override_args):
        if test_name == "hsdp+cp+validation+inference" and idx == 1:
            # Load the checkpoint saved by the training step (idx == 0) and
            # test generation via the inference script.
            cmd = f"{launch_env} torchtitan/models/flux/run_infer.sh"
        else:
            # Training run; dump a compile trace for debugging purposes.
            cmd = f"{trace_env} {launch_env} ./run_train.sh"
        cmd += " " + common_args
        if override_arg:
            cmd += " " + " ".join(override_arg)

        logger.info(
            f"=====Flux Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
        )
        result = _run_cmd(cmd)
        # stdout may be None when _run_cmd does not capture output (TODO
        # confirm against _run_cmd). Only log it when there is something to show.
        if result.stdout:
            logger.info(result.stdout)
        if result.returncode != 0:
            raise RuntimeError(
                f"Flux Integration test failed, flavor : {test_flavor.test_descr}, command : {cmd}"
            )
def run_tests(args, test_list: list[OverrideDefinitions]):
    """Run the selected Flux integration tests.

    Overrides ``run_tests`` from ``tests/integration_tests/run_tests.py``.
    FLUX needs its own command construction (encoder/tokenizer assets and an
    inference step); see ``run_single_test``.

    Args:
        args: Parsed CLI namespace. ``test_name``, ``ngpu`` and
            ``output_dir`` are read.
        test_list: Candidate flavors. A flavor runs when it matches
            ``args.test_name`` (or that is ``"all"``) and needs no more than
            ``args.ngpu`` GPUs.
    """
    ran_any_test = False
    for test_flavor in test_list:
        # Filter by test_name if specified.
        if args.test_name != "all" and test_flavor.test_name != args.test_name:
            continue
        # Skip flavors that need more GPUs than are available.
        if args.ngpu < test_flavor.ngpu:
            logger.info(
                f"Skipping test {test_flavor.test_name} that requires {test_flavor.ngpu} gpus,"
                f" because --ngpu arg is {args.ngpu}"
            )
            continue
        run_single_test(test_flavor, args.output_dir)
        ran_any_test = True

    # Without this, a mistyped --test_name runs nothing and still exits
    # successfully, which looks like a passing test run.
    if not ran_any_test:
        available_tests = [t.test_name for t in test_list]
        logger.warning(
            f"No Flux integration tests were run for --test_name '{args.test_name}'"
            f" with --ngpu {args.ngpu}. Available test names: {available_tests}"
        )
def main():
    """Parse CLI arguments and run the Flux integration test suite.

    Raises:
        RuntimeError: If ``output_dir`` already exists and is not empty.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("output_dir")
    parser.add_argument(
        "--test_name",
        default="all",
        help="test to run, acceptable values: `test_name` in `build_flux_test_list` (default: all)",
    )
    parser.add_argument(
        "--ngpu",
        default=8,
        type=int,
        help="number of GPUs available; flavors requiring more are skipped",
    )
    args = parser.parse_args()

    # exist_ok=True replaces the racy os.path.exists() + os.makedirs() pair.
    os.makedirs(args.output_dir, exist_ok=True)
    # Require an empty directory, presumably so leftovers from earlier runs
    # (e.g. checkpoints in a dump folder) cannot affect this run.
    if os.listdir(args.output_dir):
        raise RuntimeError("Please provide an empty output directory.")

    test_list = _TEST_SUITES_FUNCTION["flux"]()
    run_tests(args, test_list)


if __name__ == "__main__":
    main()