Skip to content

Commit 6a60786

Browse files
kjaniknvidiaclaude
andauthored
Enable pipeline model parallelism for Evo2 inference (#1478)
Remove the PP > 1 guard, argparse choices=[1] restriction, and hardcoded pre_process/post_process=True so the model provider auto-detects pipeline stage. Tested with PP=1, PP=2, and PP=5. ### Description For the most part I just removed the guarding that forces PP=1. There's only one functional line change. 1. Line 257 — Removed the if pipeline_model_parallel_size != 1: raise ValueError(...) guard (3 lines deleted) 2. Line 334 — Changed model_provider.provide(pre_process=True, post_process=True) to model_provider.provide() so each pipeline stage auto-detects whether it needs embedding/output layers 3. Line 508 — Removed choices=[1] from the --pipeline-model-parallel-size argparse argument 4. Lines 245, 553 — Updated docstrings removing "(must be 1)" #### Usage torchrun --nproc-per-node 2 /workspace/bionemo/src/bionemo/evo2/run/infer.py \ --ckpt-dir /workspace/bionemo/evo2_1b_8k_bf16_mbridge \ --prompt "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG" \ --max-new-tokens 10 \ --top-k 1 \ --temperature 1.0 \ --pipeline-model-parallel-size 2 torchrun --nproc-per-node 5 /workspace/bionemo/src/bionemo/evo2/run/infer.py \ --ckpt-dir /workspace/bionemo/evo2_1b_8k_bf16_mbridge \ --prompt "ATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCGATCG" \ --max-new-tokens 10 \ --top-k 1 \ --temperature 1.0 \ --pipeline-model-parallel-size 5 │ PP=1 inference (1 GPU) PASS ATCGATCGAT │ │ PP=2 inference (2 GPUs) PASS ATCGATCGAT │ │ PP=5 inference (5 GPUs) PASS ATCGATCGAT │ ### Type of changes <!-- Mark the relevant option with an [x] --> - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Refactor - [ ] Documentation update - [ ] Other (please describe): ### CI Pipeline Configuration Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run. - [ciflow:skip](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:skip) - Skip all CI tests for this PR Unit tests marked as `@pytest.mark.multi_gpu` or `@pytest.mark.distributed` are not run in the PR pipeline. For more details, see [CONTRIBUTING](CONTRIBUTING.md) > [!NOTE] > By default, only basic unit tests are run. Add appropriate labels to enable an additional test coverage. #### Authorizing CI Runs We use [copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation) to manage authorization of CI runs on NVIDIA's compute resources. - If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123) - If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an `/ok to test` comment on the pull request to trigger CI. This will need to be done for each new commit. #### Triggering Code Rabbit AI Review To trigger a code review from code rabbit, comment on a pull request with one of these commands: - @coderabbitai review - Triggers a standard review - @coderabbitai full review - Triggers a comprehensive review See https://docs.coderabbit.ai/reference/review-commands for a full list of commands. ### Pre-submit Checklist <!--- Ensure all items are completed before submitting --> - [x] I have tested these changes locally - [x] I have updated the documentation accordingly - [ ] I have added/updated tests as needed - [x] All existing tests pass successfully --------- Signed-off-by: Ken Janik <kjanik@nvidia.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b5a98d2 commit 6a60786

2 files changed

Lines changed: 30 additions & 18 deletions

File tree

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/infer.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def setup_inference_engine(
242242
max_seq_length: Maximum sequence length for generation.
243243
max_batch_size: Maximum batch size for inference.
244244
tensor_parallel_size: Tensor parallelism degree.
245-
pipeline_model_parallel_size: Pipeline parallelism degree (must be 1).
245+
pipeline_model_parallel_size: Pipeline parallelism degree.
246246
context_parallel_size: Context parallelism degree.
247247
mixed_precision_recipe: Override mixed precision recipe.
248248
random_seed: Random seed for reproducibility.
@@ -254,9 +254,6 @@ def setup_inference_engine(
254254
>>> components = setup_inference_engine(Path("/path/to/checkpoint"), max_batch_size=4)
255255
>>> results = generate(components, prompts=["ATCG", "GCTA"], max_new_tokens=100)
256256
"""
257-
if pipeline_model_parallel_size != 1:
258-
raise ValueError("Pipeline parallelism > 1 is not supported for inference.")
259-
260257
# -------------------------------------------------------------------------
261258
# Step 1: Load configuration from checkpoint
262259
# -------------------------------------------------------------------------
@@ -334,7 +331,7 @@ def setup_inference_engine(
334331
logger.info("Creating model...")
335332
model_provider.finalize()
336333

337-
raw_model = model_provider.provide(pre_process=True, post_process=True).eval().cuda()
334+
raw_model = model_provider.provide().eval().cuda()
338335

339336
logger.info(f"Loading weights from: {resolved_ckpt_dir}")
340337
_load_model_weights_from_checkpoint(
@@ -505,7 +502,7 @@ def parse_args() -> argparse.Namespace:
505502

506503
# Parallelism arguments
507504
ap.add_argument("--tensor-parallel-size", type=int, default=1, help="Tensor parallelism")
508-
ap.add_argument("--pipeline-model-parallel-size", type=int, choices=[1], default=1, help="Pipeline parallelism")
505+
ap.add_argument("--pipeline-model-parallel-size", type=int, default=1, help="Pipeline parallelism")
509506
ap.add_argument("--context-parallel-size", type=int, default=1, help="Context parallelism")
510507

511508
# Output arguments
@@ -550,7 +547,7 @@ def infer(
550547
top_p: Nucleus sampling parameter (0 = disabled).
551548
seed: Random seed for reproducibility.
552549
tensor_parallel_size: Tensor parallelism degree.
553-
pipeline_model_parallel_size: Pipeline parallelism degree (must be 1).
550+
pipeline_model_parallel_size: Pipeline parallelism degree.
554551
context_parallel_size: Context parallelism degree.
555552
output_file: Optional path to save generated text.
556553
mixed_precision_recipe: Override mixed precision recipe.

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/run/test_infer.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,7 @@ def run_infer_subprocess_parallel(
363363
top_k: int = 1,
364364
seed: int = 42,
365365
tensor_parallel_size: int = 1,
366+
pipeline_model_parallel_size: int = 1,
366367
context_parallel_size: int = 1,
367368
):
368369
"""Helper to run inference as a subprocess with model parallelism.
@@ -379,12 +380,13 @@ def run_infer_subprocess_parallel(
379380
top_k: Top-k sampling parameter (1 for greedy).
380381
seed: Random seed for reproducibility.
381382
tensor_parallel_size: Tensor parallelism degree.
383+
pipeline_model_parallel_size: Pipeline parallelism degree.
382384
context_parallel_size: Context parallelism degree.
383385
384386
Returns:
385387
The generated text from the output file.
386388
"""
387-
nproc_per_node = tensor_parallel_size * context_parallel_size
389+
nproc_per_node = tensor_parallel_size * pipeline_model_parallel_size * context_parallel_size
388390
open_port = find_free_network_port()
389391

390392
cmd = [
@@ -412,6 +414,8 @@ def run_infer_subprocess_parallel(
412414
str(seed),
413415
"--tensor-parallel-size",
414416
str(tensor_parallel_size),
417+
"--pipeline-model-parallel-size",
418+
str(pipeline_model_parallel_size),
415419
"--context-parallel-size",
416420
str(context_parallel_size),
417421
]
@@ -625,29 +629,39 @@ def mbridge_checkpoint_7b_1m_path(tmp_path_factory) -> Path:
625629
@pytest.mark.slow
626630
@pytest.mark.timeout(900)
627631
@pytest.mark.parametrize(
628-
"tp, cp",
632+
"tp, pp, cp",
629633
[
630634
# The 7b model has 32 attention heads, supporting TP=1, 2, 4, 8
631-
pytest.param(1, 1, id="tp=1,cp=1"),
632-
pytest.param(2, 1, id="tp=2,cp=1"),
633-
pytest.param(4, 1, id="tp=4,cp=1"),
634-
pytest.param(8, 1, id="tp=8,cp=1"),
635+
# TP-only configs
636+
pytest.param(1, 1, 1, id="tp=1,pp=1,cp=1"),
637+
pytest.param(2, 1, 1, id="tp=2,pp=1,cp=1"),
638+
pytest.param(4, 1, 1, id="tp=4,pp=1,cp=1"),
639+
pytest.param(8, 1, 1, id="tp=8,pp=1,cp=1"),
640+
# PP-only configs
641+
pytest.param(1, 2, 1, id="tp=1,pp=2,cp=1"),
642+
pytest.param(1, 4, 1, id="tp=1,pp=4,cp=1"),
643+
pytest.param(1, 8, 1, id="tp=1,pp=8,cp=1"),
644+
# Combined TP+PP configs
645+
pytest.param(2, 2, 1, id="tp=2,pp=2,cp=1"),
646+
pytest.param(4, 2, 1, id="tp=4,pp=2,cp=1"),
647+
# CP>1 configs (known broken)
635648
pytest.param(
649+
1,
636650
1,
637651
2,
638-
id="tp=1,cp=2",
652+
id="tp=1,pp=1,cp=2",
639653
marks=pytest.mark.xfail(reason="CP>1 is known broken for inference", strict=False),
640654
),
641655
],
642656
)
643657
@pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI")
644-
def test_parallel_inference_accuracy_7b(mbridge_checkpoint_7b_1m_path, tmp_path, dna_sequences, tp, cp):
658+
def test_parallel_inference_accuracy_7b(mbridge_checkpoint_7b_1m_path, tmp_path, dna_sequences, tp, pp, cp):
645659
"""Test that parallel inference with the 7b model produces accurate generation results.
646660
647-
Uses the 7b-1m checkpoint which supports TP>1 (32 attention heads), enabling
648-
proper tensor parallel accuracy testing that the 1b model cannot support.
661+
Uses the 7b-1m checkpoint which supports TP>1 (32 attention heads) and PP>1,
662+
enabling proper tensor and pipeline parallel accuracy testing.
649663
"""
650-
num_gpus_required = tp * cp
664+
num_gpus_required = tp * pp * cp
651665
if torch.cuda.device_count() < num_gpus_required:
652666
pytest.skip(f"Not enough GPUs: need {num_gpus_required}, have {torch.cuda.device_count()}")
653667

@@ -672,6 +686,7 @@ def test_parallel_inference_accuracy_7b(mbridge_checkpoint_7b_1m_path, tmp_path,
672686
top_k=1, # Greedy decoding
673687
seed=42,
674688
tensor_parallel_size=tp,
689+
pipeline_model_parallel_size=pp,
675690
context_parallel_size=cp,
676691
)
677692

0 commit comments

Comments
 (0)