Skip to content

Commit 51c8a67

Browse files
committed
add partial conv tests to esm2_accelerate recipe
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent bc468a7 commit 51c8a67

3 files changed

Lines changed: 91 additions & 9 deletions

File tree

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
defaults:
22
- defaults
3+
- _self_
34

45
model_tag: "nvidia/esm2_t6_8M_UR50D"
5-
stop_after_n_steps: 4
6+
stop_after_n_steps: 250
7+
68
trainer:
79
run_name: "esm2_t6_8M_UR50D_sanity"
810
per_device_train_batch_size: 2
911
per_device_eval_batch_size: 2
10-
save_steps: 2
11-
eval_steps: 2
12-
logging_steps: 1
12+
save_steps: 1000
13+
eval_steps: 1000
14+
logging_steps: 10
1315
report_to: "none"
1416
dataloader_num_workers: 0
17+
warmup_steps: 0

recipes/esm2_accelerate/test_train.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import os
17+
import random
1718
import re
1819
import shutil
1920
import subprocess
@@ -40,6 +41,38 @@
4041
)
4142

4243

44+
def extract_final_train_loss(output_text: str) -> float:
45+
"""
46+
Parse the training output to extract the final train_loss value.
47+
48+
Args:
49+
output_text: Combined stdout and stderr from training process
50+
51+
Returns:
52+
Final train_loss value as float
53+
54+
Raises:
55+
ValueError: If no train_loss found or parsing fails
56+
"""
57+
# Look for dictionary-like patterns containing train_loss
58+
# Pattern matches: {'key': value, 'train_loss': value, ...}
59+
pattern = r'\{[^{}]*[\'"]train_loss[\'"]:\s*([0-9.]+)[^{}]*\}'
60+
61+
matches = re.findall(pattern, output_text)
62+
63+
if not matches:
64+
# Fallback: try to find train_loss in any context
65+
simple_pattern = r'[\'"]train_loss[\'"]:\s*([0-9.]+)'
66+
matches = re.findall(simple_pattern, output_text)
67+
68+
if not matches:
69+
raise ValueError("No train_loss found in training output")
70+
71+
# Return the last (final) train_loss value found
72+
final_train_loss = float(matches[-1])
73+
return final_train_loss
74+
75+
4376
def test_train_can_resume_from_checkpoint(monkeypatch, tmp_path: Path):
4477
"""Test that train.py runs successfully with sanity config and creates expected outputs."""
4578

@@ -51,11 +84,19 @@ def test_train_can_resume_from_checkpoint(monkeypatch, tmp_path: Path):
5184
monkeypatch.setenv("RANK", "0")
5285
monkeypatch.setenv("WORLD_SIZE", "1")
5386
monkeypatch.setenv("MASTER_ADDR", "localhost")
54-
monkeypatch.setenv("MASTER_PORT", "29500")
87+
monkeypatch.setenv("MASTER_PORT", f"{random.randint(20000, 40000)}")
5588
monkeypatch.setenv("WANDB_MODE", "disabled")
5689

5790
with initialize_config_dir(config_dir=str(recipe_dir / "hydra_config"), version_base="1.2"):
58-
sanity_config = compose(config_name="L0_sanity", overrides=[f"trainer.output_dir={tmp_path}"])
91+
sanity_config = compose(
92+
config_name="L0_sanity",
93+
overrides=[
94+
f"trainer.output_dir={tmp_path}",
95+
"stop_after_n_steps=4",
96+
"trainer.do_eval=False",
97+
"trainer.save_steps=2",
98+
],
99+
)
59100

60101
main(sanity_config)
61102

@@ -155,11 +196,14 @@ def test_accelerate_launch(accelerate_config, model_tag, tmp_path):
155196
str(accelerate_config_path),
156197
"--num_processes",
157198
"1",
199+
"--main_process_port",
200+
f"{random.randint(20000, 40000)}",
158201
str(train_py),
159202
"--config-name",
160203
"L0_sanity.yaml",
161204
f"model_tag={model_tag}",
162205
f"trainer.output_dir={tmp_path}",
206+
"trainer.do_eval=False",
163207
]
164208

165209
result = subprocess.run(
@@ -176,6 +220,17 @@ def test_accelerate_launch(accelerate_config, model_tag, tmp_path):
176220
print(f"STDERR:\n{result.stderr}")
177221
pytest.fail(f"Command:\n{' '.join(cmd)}\nfailed with exit code {result.returncode}")
178222

223+
# Parse the training output to check final train_loss
224+
combined_output = result.stdout + result.stderr
225+
try:
226+
final_train_loss = extract_final_train_loss(combined_output)
227+
print(f"Final train_loss: {final_train_loss}")
228+
assert final_train_loss < 3.0, f"Final train_loss {final_train_loss} should be less than 3.0"
229+
except ValueError as e:
230+
print(f"STDOUT:\n{result.stdout}")
231+
print(f"STDERR:\n{result.stderr}")
232+
pytest.fail(f"Failed to extract train_loss from output: {e}")
233+
179234

180235
@requires_multi_gpu
181236
@pytest.mark.parametrize(
@@ -211,11 +266,14 @@ def test_accelerate_launch_multi_gpu(accelerate_config, model_tag, tmp_path):
211266
str(accelerate_config_path),
212267
"--num_processes",
213268
"2",
269+
"--main_process_port",
270+
f"{random.randint(20000, 40000)}",
214271
str(train_py),
215272
"--config-name",
216273
"L0_sanity.yaml",
217274
f"model_tag={model_tag}",
218275
f"trainer.output_dir={tmp_path}",
276+
"trainer.do_eval=False",
219277
]
220278

221279
result = subprocess.run(
@@ -231,3 +289,15 @@ def test_accelerate_launch_multi_gpu(accelerate_config, model_tag, tmp_path):
231289
print(f"STDOUT:\n{result.stdout}")
232290
print(f"STDERR:\n{result.stderr}")
233291
pytest.fail(f"Command:\n{' '.join(cmd)}\nfailed with exit code {result.returncode}")
292+
293+
# Parse the training output to check final train_loss
294+
combined_output = result.stdout + result.stderr
295+
try:
296+
final_train_loss = extract_final_train_loss(combined_output)
297+
breakpoint()
298+
print(f"Final train_loss: {final_train_loss}")
299+
assert final_train_loss < 3.0, f"Final train_loss {final_train_loss} should be less than 3.0"
300+
except ValueError as e:
301+
print(f"STDOUT:\n{result.stdout}")
302+
print(f"STDERR:\n{result.stderr}")
303+
pytest.fail(f"Failed to extract train_loss from output: {e}")

recipes/esm2_accelerate/train.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,12 @@ def main(args: DictConfig):
3838
config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True)
3939
config.max_seq_length = args.max_seq_length
4040
config.micro_batch_size = args.trainer.per_device_train_batch_size
41-
model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True, torch_dtype=torch.bfloat16)
41+
42+
model = AutoModelForMaskedLM.from_config(
43+
config,
44+
trust_remote_code=True,
45+
torch_dtype=torch.bfloat16,
46+
)
4247

4348
train_dataset, eval_dataset, data_collator = create_datasets_and_collator(
4449
tokenizer_name=args.model_tag,
@@ -57,7 +62,7 @@ def main(args: DictConfig):
5762
callbacks=[StopAfterNStepsCallback(args.stop_after_n_steps)],
5863
)
5964

60-
logger.info("ACCELERATE STATE:\n%s\n", trainer.accelerator.state)
65+
train_result, eval_result = None, None
6166

6267
if training_args.do_train:
6368
Path(training_args.output_dir).mkdir(parents=True, exist_ok=True)
@@ -72,7 +77,11 @@ def main(args: DictConfig):
7277
trainer.save_model(str(Path(training_args.output_dir) / "checkpoint-last"))
7378

7479
if training_args.do_eval:
75-
trainer.evaluate()
80+
eval_result = trainer.evaluate()
81+
logger.info("Evaluation complete. Metrics: %s", eval_result)
82+
trainer.save_metrics("eval", eval_result)
83+
84+
return train_result, eval_result
7685

7786

7887
if __name__ == "__main__":

0 commit comments

Comments
 (0)