1414# limitations under the License.
1515
1616import os
17+ import random
1718import re
1819import shutil
1920import subprocess
4041)
4142
4243
44+ def extract_final_train_loss (output_text : str ) -> float :
45+ """
46+ Parse the training output to extract the final train_loss value.
47+
48+ Args:
49+ output_text: Combined stdout and stderr from training process
50+
51+ Returns:
52+ Final train_loss value as float
53+
54+ Raises:
55+ ValueError: If no train_loss found or parsing fails
56+ """
57+ # Look for dictionary-like patterns containing train_loss
58+ # Pattern matches: {'key': value, 'train_loss': value, ...}
59+ pattern = r'\{[^{}]*[\'"]train_loss[\'"]:\s*([0-9.]+)[^{}]*\}'
60+
61+ matches = re .findall (pattern , output_text )
62+
63+ if not matches :
64+ # Fallback: try to find train_loss in any context
65+ simple_pattern = r'[\'"]train_loss[\'"]:\s*([0-9.]+)'
66+ matches = re .findall (simple_pattern , output_text )
67+
68+ if not matches :
69+ raise ValueError ("No train_loss found in training output" )
70+
71+ # Return the last (final) train_loss value found
72+ final_train_loss = float (matches [- 1 ])
73+ return final_train_loss
74+
75+
4376def test_train_can_resume_from_checkpoint (monkeypatch , tmp_path : Path ):
4477 """Test that train.py runs successfully with sanity config and creates expected outputs."""
4578
@@ -51,11 +84,19 @@ def test_train_can_resume_from_checkpoint(monkeypatch, tmp_path: Path):
5184 monkeypatch .setenv ("RANK" , "0" )
5285 monkeypatch .setenv ("WORLD_SIZE" , "1" )
5386 monkeypatch .setenv ("MASTER_ADDR" , "localhost" )
54- monkeypatch .setenv ("MASTER_PORT" , "29500 " )
87+ monkeypatch .setenv ("MASTER_PORT" , f" { random . randint ( 20000 , 40000 ) } " )
5588 monkeypatch .setenv ("WANDB_MODE" , "disabled" )
5689
5790 with initialize_config_dir (config_dir = str (recipe_dir / "hydra_config" ), version_base = "1.2" ):
58- sanity_config = compose (config_name = "L0_sanity" , overrides = [f"trainer.output_dir={ tmp_path } " ])
91+ sanity_config = compose (
92+ config_name = "L0_sanity" ,
93+ overrides = [
94+ f"trainer.output_dir={ tmp_path } " ,
95+ "stop_after_n_steps=4" ,
96+ "trainer.do_eval=False" ,
97+ "trainer.save_steps=2" ,
98+ ],
99+ )
59100
60101 main (sanity_config )
61102
@@ -155,11 +196,14 @@ def test_accelerate_launch(accelerate_config, model_tag, tmp_path):
155196 str (accelerate_config_path ),
156197 "--num_processes" ,
157198 "1" ,
199+ "--main_process_port" ,
200+ f"{ random .randint (20000 , 40000 )} " ,
158201 str (train_py ),
159202 "--config-name" ,
160203 "L0_sanity.yaml" ,
161204 f"model_tag={ model_tag } " ,
162205 f"trainer.output_dir={ tmp_path } " ,
206+ "trainer.do_eval=False" ,
163207 ]
164208
165209 result = subprocess .run (
@@ -176,6 +220,17 @@ def test_accelerate_launch(accelerate_config, model_tag, tmp_path):
176220 print (f"STDERR:\n { result .stderr } " )
177221 pytest .fail (f"Command:\n { ' ' .join (cmd )} \n failed with exit code { result .returncode } " )
178222
223+ # Parse the training output to check final train_loss
224+ combined_output = result .stdout + result .stderr
225+ try :
226+ final_train_loss = extract_final_train_loss (combined_output )
227+ print (f"Final train_loss: { final_train_loss } " )
228+ assert final_train_loss < 3.0 , f"Final train_loss { final_train_loss } should be less than 3.0"
229+ except ValueError as e :
230+ print (f"STDOUT:\n { result .stdout } " )
231+ print (f"STDERR:\n { result .stderr } " )
232+ pytest .fail (f"Failed to extract train_loss from output: { e } " )
233+
179234
180235@requires_multi_gpu
181236@pytest .mark .parametrize (
@@ -211,11 +266,14 @@ def test_accelerate_launch_multi_gpu(accelerate_config, model_tag, tmp_path):
211266 str (accelerate_config_path ),
212267 "--num_processes" ,
213268 "2" ,
269+ "--main_process_port" ,
270+ f"{ random .randint (20000 , 40000 )} " ,
214271 str (train_py ),
215272 "--config-name" ,
216273 "L0_sanity.yaml" ,
217274 f"model_tag={ model_tag } " ,
218275 f"trainer.output_dir={ tmp_path } " ,
276+ "trainer.do_eval=False" ,
219277 ]
220278
221279 result = subprocess .run (
@@ -231,3 +289,15 @@ def test_accelerate_launch_multi_gpu(accelerate_config, model_tag, tmp_path):
231289 print (f"STDOUT:\n { result .stdout } " )
232290 print (f"STDERR:\n { result .stderr } " )
233291 pytest .fail (f"Command:\n { ' ' .join (cmd )} \n failed with exit code { result .returncode } " )
292+
293+ # Parse the training output to check final train_loss
294+ combined_output = result .stdout + result .stderr
295+ try :
296+ final_train_loss = extract_final_train_loss (combined_output )
297+ breakpoint ()
298+ print (f"Final train_loss: { final_train_loss } " )
299+ assert final_train_loss < 3.0 , f"Final train_loss { final_train_loss } should be less than 3.0"
300+ except ValueError as e :
301+ print (f"STDOUT:\n { result .stdout } " )
302+ print (f"STDERR:\n { result .stderr } " )
303+ pytest .fail (f"Failed to extract train_loss from output: { e } " )
0 commit comments