Skip to content

Commit b4cb0de

Browse files
committed
Clean up param names, tune skips, and try to speed up resumption
Signed-off-by: John St John <jstjohn@nvidia.com>
1 parent b0473cb commit b4cb0de

3 files changed

Lines changed: 43 additions & 19 deletions

File tree

sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -507,12 +507,8 @@ def predict(
507507
resume_if_exists=True,
508508
resume_ignore_no_checkpoint=False,
509509
resume_past_end=False,
510-
restore_config=nl.RestoreConfig(
511-
path=str(ckpt_dir), # NeMo expects a string path.
512-
load_model_state=True,
513-
load_optim_state=False,
514-
load_artifacts=False,
515-
),
510+
resume_from_path=str(ckpt_dir),
511+
restore_config=None,
516512
)
517513
tokenizer = get_nmt_tokenizer("byte-level")
518514

sub-packages/bionemo-evo2/tests/bionemo/evo2/run/test_predict.py

Lines changed: 39 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
from bionemo.core.data.load import load
3131
from bionemo.llm.lightning import batch_collator
3232
from bionemo.testing.data.fasta import ALU_SEQUENCE, create_fasta_file
33+
from bionemo.testing.torch import check_fp8_support
3334

3435

3536
def is_a6000_gpu() -> bool:
@@ -74,11 +75,18 @@ def checkpoint_7b_1m_path() -> Path:
7475
@pytest.mark.parametrize(
7576
"ddp,pp,tp,wi",
7677
[
77-
(1, 1, 1, "epoch"),
78-
(2, 1, 1, "epoch"),
79-
(2, 1, 1, "batch"),
80-
(1, 2, 1, "epoch"),
81-
(1, 1, 2, "epoch"),
78+
pytest.param(1, 1, 1, "epoch", id="ddp=1,pp=1,tp=1,wi=epoch"),
79+
pytest.param(2, 1, 1, "epoch", id="ddp=2,pp=1,tp=1,wi=epoch"),
80+
pytest.param(2, 1, 1, "batch", id="ddp=2,pp=1,tp=1,wi=batch"),
81+
pytest.param(
82+
1,
83+
2,
84+
1,
85+
"epoch",
86+
id="ddp=1,pp=2,tp=1,wi=epoch",
87+
marks=pytest.mark.skip("Pipeline parallelism test currently hangs."),
88+
),
89+
pytest.param(1, 1, 2, "epoch", id="ddp=1,pp=1,tp=2,wi=epoch"),
8290
],
8391
)
8492
def test_predict_evo2_runs(
@@ -177,15 +185,29 @@ def test_predict_evo2_runs(
177185
@pytest.mark.parametrize(
178186
"ddp,cp,pp,tp,fp8,wi",
179187
[
180-
(1, 1, 1, 1, False, "epoch"),
181-
(2, 1, 1, 1, False, "epoch"),
182-
(2, 1, 1, 1, False, "batch"), # simulate a large prediction run with dp parallelism
183-
(1, 2, 1, 1, False, "epoch"),
184-
(1, 2, 1, 1, False, "batch"),
185-
(1, 1, 2, 1, False, "epoch"),
186-
(1, 1, 2, 1, True, "epoch"), # Cover case where FP8 was not supported with TP=2
187-
(1, 1, 1, 2, False, "epoch"),
188+
pytest.param(1, 1, 1, 1, False, "epoch", id="ddp=1,cp=1,pp=1,tp=1,fp8=False,wi=epoch"),
189+
pytest.param(2, 1, 1, 1, False, "epoch", id="ddp=2,cp=1,pp=1,tp=1,fp8=False,wi=epoch"),
190+
pytest.param(
191+
2, 1, 1, 1, False, "batch", id="ddp=2,cp=1,pp=1,tp=1,fp8=False,wi=batch"
192+
), # simulate a large prediction run with dp parallelism
193+
pytest.param(1, 2, 1, 1, False, "epoch", id="ddp=1,cp=2,pp=1,tp=1,fp8=False,wi=epoch"),
194+
pytest.param(1, 2, 1, 1, False, "batch", id="ddp=1,cp=2,pp=1,tp=1,fp8=False,wi=batch"),
195+
pytest.param(
196+
1,
197+
1,
198+
2,
199+
1,
200+
False,
201+
"epoch",
202+
id="ddp=1,cp=1,pp=2,tp=1,fp8=False,wi=epoch",
203+
marks=pytest.mark.skip("Pipeline parallelism test currently hangs."),
204+
),
205+
pytest.param(
206+
1, 1, 1, 2, True, "epoch", id="ddp=1,cp=1,pp=1,tp=2,fp8=True,wi=epoch"
207+
), # Cover case where FP8 was not supported with TP=2
208+
pytest.param(1, 1, 1, 2, False, "epoch", id="ddp=1,cp=1,pp=1,tp=2,fp8=False,wi=epoch"),
188209
],
210+
ids=lambda x: f"ddp={x[0]},cp={x[1]},pp={x[2]},tp={x[3]},fp8={x[4]},wi={x[5]}",
189211
)
190212
def test_predict_evo2_runs_with_log_probs(
191213
tmp_path,
@@ -210,6 +232,9 @@ def test_predict_evo2_runs_with_log_probs(
210232
world_size = ddp * cp * pp * tp
211233
if world_size > torch.cuda.device_count():
212234
pytest.skip(f"World size {world_size} is less than the number of GPUs {torch.cuda.device_count()}")
235+
is_fp8_supported, _, _ = check_fp8_support(torch.cuda.current_device())
236+
if not is_fp8_supported and fp8:
237+
pytest.skip("FP8 is not supported on this GPU.")
213238

214239
fasta_file_path = tmp_path / "test.fasta"
215240
create_fasta_file(
@@ -221,6 +246,7 @@ def test_predict_evo2_runs_with_log_probs(
221246
if is_a6000_gpu():
222247
# Fix hanging issue on A6000 GPUs with multi-gpu tests
223248
env["NCCL_P2P_DISABLE"] = "1"
249+
224250
fp8_option = "--fp8" if fp8 else ""
225251
# Build the command string.
226252
# Note: The command assumes that `train_evo2` is in your PATH.

sub-packages/bionemo-llm/src/bionemo/llm/utils/callbacks.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,8 @@
1616

1717
import os
1818
from typing import Any, Literal, Sequence
19+
20+
1921
try: # Python 3.12+
2022
from typing import override
2123
except ImportError: # Python < 3.12

0 commit comments

Comments
 (0)