2727from tokenizers import Tokenizer
2828from tokenizers .models import WordLevel
2929from tokenizers .pre_tokenizers import Whitespace
30- from transformers import PreTrainedTokenizerFast
31-
3230from train_ddp import main as main_ddp
3331from train_fsdp2 import main as main_fsdp2
32+ from transformers import PreTrainedTokenizerFast
3433
3534
3635os .environ ["WANDB_DISABLED" ] = "true"
@@ -111,15 +110,21 @@ def _assert_checkpoint_step(ckpt_subdir, step, num_ranks, is_ddp, use_distribute
111110 model_files = [f for f in files if f .startswith ("model_rank_" )]
112111 optimizer_files = [f for f in files if f .startswith ("optimizer_rank_" )]
113112 assert len (model_files ) >= num_ranks , f"Expected model files for { num_ranks } ranks in { step_dir } : { files } "
114- assert len (optimizer_files ) >= num_ranks , f"Expected optimizer files for { num_ranks } ranks in { step_dir } : { files } "
113+ assert len (optimizer_files ) >= num_ranks , (
114+ f"Expected optimizer files for { num_ranks } ranks in { step_dir } : { files } "
115+ )
115116 assert "metadata.pt" in files , f"Missing metadata.pt in { step_dir } : { files } "
116117 dataloader_files = [f for f in files if "dataloader" in f ]
117- assert len (dataloader_files ) >= num_ranks , f"Expected dataloader files for { num_ranks } ranks in { step_dir } : { files } "
118+ assert len (dataloader_files ) >= num_ranks , (
119+ f"Expected dataloader files for { num_ranks } ranks in { step_dir } : { files } "
120+ )
118121
119122
120123def _run_single_process_checkpoint_test (recipe_path , tmp_path , main_fn , ckpt_subdir_name , extra_overrides , is_ddp ):
121124 tokenizer_path = _create_local_tokenizer (tmp_path )
122- expert_parallel_size = int (next (o .split ("=" , 1 )[1 ] for o in extra_overrides if o .startswith ("expert_parallel_size=" )))
125+ expert_parallel_size = int (
126+ next (o .split ("=" , 1 )[1 ] for o in extra_overrides if o .startswith ("expert_parallel_size=" ))
127+ )
123128 use_distributed_checkpoint = is_ddp and expert_parallel_size > 1
124129 common = [
125130 "checkpoint.save_every_n_steps=5" ,
@@ -138,7 +143,9 @@ def _run_single_process_checkpoint_test(recipe_path, tmp_path, main_fn, ckpt_sub
138143 torch .cuda .empty_cache ()
139144
140145 ckpt_subdir = os .path .join (str (tmp_path / "ckpt" ), ckpt_subdir_name )
141- _assert_checkpoint_step (ckpt_subdir , 5 , num_ranks = 1 , is_ddp = is_ddp , use_distributed_checkpoint = use_distributed_checkpoint )
146+ _assert_checkpoint_step (
147+ ckpt_subdir , 5 , num_ranks = 1 , is_ddp = is_ddp , use_distributed_checkpoint = use_distributed_checkpoint
148+ )
142149
143150 cfg2 = _compose_config (
144151 recipe_path ,
@@ -149,7 +156,9 @@ def _run_single_process_checkpoint_test(recipe_path, tmp_path, main_fn, ckpt_sub
149156 gc .collect ()
150157 torch .cuda .empty_cache ()
151158
152- _assert_checkpoint_step (ckpt_subdir , 5 , num_ranks = 1 , is_ddp = is_ddp , use_distributed_checkpoint = use_distributed_checkpoint )
159+ _assert_checkpoint_step (
160+ ckpt_subdir , 5 , num_ranks = 1 , is_ddp = is_ddp , use_distributed_checkpoint = use_distributed_checkpoint
161+ )
153162 _assert_checkpoint_step (
154163 ckpt_subdir , 10 , num_ranks = 1 , is_ddp = is_ddp , use_distributed_checkpoint = use_distributed_checkpoint
155164 )
@@ -162,7 +171,9 @@ def _run_multi_process_checkpoint_test(
162171):
163172 ckpt_dir = str (tmp_path / "ckpt" )
164173 tokenizer_path = _create_local_tokenizer (tmp_path )
165- expert_parallel_size = int (next (o .split ("=" , 1 )[1 ] for o in extra_overrides if o .startswith ("expert_parallel_size=" )))
174+ expert_parallel_size = int (
175+ next (o .split ("=" , 1 )[1 ] for o in extra_overrides if o .startswith ("expert_parallel_size=" ))
176+ )
166177 use_distributed_checkpoint = is_ddp and expert_parallel_size > 1
167178 env = os .environ .copy ()
168179 env ["WANDB_MODE" ] = "disabled"
@@ -194,7 +205,9 @@ def _run_multi_process_checkpoint_test(
194205 assert result1 .returncode == 0 , f"Phase 1 failed: { result1 .stderr } "
195206
196207 ckpt_subdir = os .path .join (ckpt_dir , ckpt_subdir_name )
197- _assert_checkpoint_step (ckpt_subdir , 5 , num_ranks = 2 , is_ddp = is_ddp , use_distributed_checkpoint = use_distributed_checkpoint )
208+ _assert_checkpoint_step (
209+ ckpt_subdir , 5 , num_ranks = 2 , is_ddp = is_ddp , use_distributed_checkpoint = use_distributed_checkpoint
210+ )
198211
199212 result2 = subprocess .run (
200213 [* base_cmd , "num_train_steps=15" , "checkpoint.resume_from_checkpoint=true" , * common ],
@@ -205,7 +218,9 @@ def _run_multi_process_checkpoint_test(
205218 )
206219 assert result2 .returncode == 0 , f"Phase 2 failed: { result2 .stderr } "
207220
208- _assert_checkpoint_step (ckpt_subdir , 5 , num_ranks = 2 , is_ddp = is_ddp , use_distributed_checkpoint = use_distributed_checkpoint )
221+ _assert_checkpoint_step (
222+ ckpt_subdir , 5 , num_ranks = 2 , is_ddp = is_ddp , use_distributed_checkpoint = use_distributed_checkpoint
223+ )
209224 _assert_checkpoint_step (
210225 ckpt_subdir , 10 , num_ranks = 2 , is_ddp = is_ddp , use_distributed_checkpoint = use_distributed_checkpoint
211226 )
0 commit comments