megatron-lm: bump to NGC pytorch:26.02 and add Llama 3 8B sbatch #1071

Open
KeitaW wants to merge 5 commits into main from kw/megatron-lm-pytorch26-b300

Conversation

KeitaW (Collaborator) commented Apr 29, 2026

Summary

Bumps the Megatron-LM test case to nvcr.io/nvidia/pytorch:26.02-py3 and adds
an sbatch tuned for Llama 3 8B on P6-B300. The base image ships CUDA 13 + a
recent NCCL with native sm_103 binaries, eliminating the PTX-JIT fallback
on Blackwell Ultra.

Stack delta

Component            Before          After
NGC PyTorch base     25.06-py3       26.02-py3
GDRCopy              v2.5.1          v2.5.2
EFA installer        1.43.2          1.48.0
NCCL (ARG)           (inherited)     v2.30.4-1 (declared explicitly)
AWS_OFI_NCCL (ARG)   (inherited)     v1.19.0 (declared explicitly)
transformers         4.52.4          4.57.6
Megatron-LM          core_v0.12.1    core_v0.17.0

NCCL_VERSION and AWS_OFI_NCCL_VERSION are declared explicitly so the
repo's CI version-gate (which greps the Dockerfile for nccl/efa lines
and parses versions) sees values at or above the enforced minimums
(EFA ≥ 1.47.0, NCCL ≥ 2.28).
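
For context, a hypothetical sketch of the kind of grep-based gate described
above; the script, file path, and ARG names are assumptions, not the repo's
actual CI:

    #!/usr/bin/env bash
    # Hypothetical version gate: pull ARG values out of the Dockerfile and
    # compare them against the minimums quoted above (EFA >= 1.47.0, NCCL >= 2.28).
    set -euo pipefail
    dockerfile=${1:-Dockerfile}

    get_arg() {  # value of "ARG <NAME>=<value>", with a leading "v" stripped
      grep -E "^ARG +$1=" "$dockerfile" | head -n1 | cut -d= -f2 | sed 's/^v//'
    }
    ver_ge() {   # true if $1 >= $2 under version sort
      [ "$(printf '%s\n%s\n' "$2" "$1" | sort -V | head -n1)" = "$2" ]
    }

    efa=$(get_arg EFA_INSTALLER_VERSION)
    nccl=$(get_arg NCCL_VERSION | cut -d- -f1)   # drop the "-1" package suffix
    ver_ge "$efa"  "1.47.0" || { echo "EFA installer $efa < 1.47.0"; exit 1; }
    ver_ge "$nccl" "2.28"   || { echo "NCCL $nccl < 2.28"; exit 1; }
    echo "version gate OK: EFA=$efa NCCL=$nccl"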

New: slurm/llama3/pretrain-llama3-8b.sbatch

Drop-in companion to slurm/llama2/. Defaults tuned for 8× B300 per node:
TP=1, PP=1, CP=2, seq_length=8192, MBS=1, GBS=512, --bf16,
--use-flash-attn, --transformer-impl transformer_engine, Llama 3 RoPE
(--rotary-base 500000), HuggingFaceTokenizer.
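
For reference, a minimal sketch of the corresponding argument block; the
variable names and array layout are illustrative assumptions, only the flag
values mirror the defaults listed above:

    # Illustrative excerpt only; variable names are assumptions, values mirror
    # the Llama 3 8B / B300 defaults described above.
    : "${TP:=1}" "${PP:=1}" "${CP:=2}"
    : "${MBS:=1}" "${GBS:=512}" "${SEQ_LEN:=8192}"

    MEGATRON_ARGS=(
        --tensor-model-parallel-size "$TP"
        --pipeline-model-parallel-size "$PP"
        --context-parallel-size "$CP"
        --seq-length "$SEQ_LEN"
        --micro-batch-size "$MBS"
        --global-batch-size "$GBS"
        --bf16
        --use-flash-attn
        --transformer-impl transformer_engine
        --rotary-base 500000                    # Llama 3 RoPE base (Llama 2 uses 10000)
        --tokenizer-type HuggingFaceTokenizer
        --tokenizer-model meta-llama/Meta-Llama-3-8B
    )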

Compatibility fixes (required for core_v0.17.0)

Two bugs that the new container surfaces, both fixed in this PR:

  1. pretrain-llama2.sbatch is missing --eval-interval:
    core_v0.17.0's data-iterator builder dereferences args.eval_interval
    even when --eval-iters 0 is set
    (megatron/training/training.py:1143):

    eval_iters = (args.train_iters // args.eval_interval + 1) * args.eval_iters
    TypeError: unsupported operand type(s) for //: 'int' and 'NoneType'

    This crashes before iteration 1 with the new container. The shipped gpt3
    sbatch already passes --eval-interval 1000; this PR adds it to the llama2
    sbatch too (see the sketch after this list).

  2. helpers_cpp C++ extension is built lazily by rank 0 (Dockerfile fix):
    core_v0.17.0 lazy-builds megatron.core.datasets.helpers_cpp the first
    time a dataset is accessed. The build runs on rank 0 only, but /workspace
    inside each Pyxis container is node-local: when training spans multiple
    nodes, ranks on the other nodes never see the rank-0 build and crash with
    ModuleNotFoundError: No module named 'megatron.core.datasets.helpers_cpp'.
    Fixed by pre-building the extension during docker build. Single-node runs
    pass without the fix because there is only one node; this was the only
    barrier to multi-node runs on the post image.
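
A minimal sketch of how fix 1 lands in the sbatch; the args-array layout is
an assumption, only the two eval flags are what the PR adds:

    # --eval-iters 0 still disables evaluation at runtime; --eval-interval only
    # needs to be non-None so the data-iterator arithmetic quoted above works.
    MEGATRON_ARGS+=(
        --eval-iters 0
        --eval-interval 1000
    )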

Performance

Hardware: SageMaker HyperPod Slurm cluster, p6-b300.48xlarge (B300, sm_103),
8 GPUs/node. Llama 3 8B-equivalent (pretrain_gpt.py with the new sbatch
shape: 32L × 4096H × 32A, GQA 8, seq_len=8192, MBS=1, GBS=512, CP=2,
mock data), --log-throughput. Steady-state averaged over iters 5–25.

Nodes   GPUs   Pre TFLOPS/GPU              Post TFLOPS/GPU   Δ
1       8      710                         809               +13.9%
4       32     LOAD_FAIL (NCCL/OFI hang)   782               n/a
8       64     LOAD_FAIL                   760               n/a
16      128    LOAD_FAIL                   740               n/a
32      256    LOAD_FAIL                   732               n/a

  • Pre baseline (1n): 707–713 TFLOPS/GPU, stable across iters 20–29 of a
    100-iter run.
  • Post 1n: 800–820 TFLOPS/GPU steady state (mean 809).
  • Post 4–32n: 732–782 TFLOPS/GPU. The per-rank floor drops ~10% from 1n to
    32n, which is consistent with CP=2 communication overhead growing with
    cluster size: there is no grad-accumulation buffer to hide the all-gather,
    so ~700+ TFLOPS/GPU at 256 GPUs is the right neighborhood for this shape
    on B300.

Multi-node pre = LOAD_FAIL: all four pre runs (4/8/16/32n) hung at
"> building train, validation, and test datasets for GPT ..." with
"NCCL WARN NET/OFI Attempt to call recv_close_deferred with outstanding
requests" and never reached iteration 1. The old NCCL 2.27.3 plus the
bundled OFI plugin in NGC pytorch:25.06 does not survive the multi-rank
dataset-builder collectives at our scale. That failure is itself part of
the case for the bump: the pre stack cannot run multi-node on B300 at all,
while the post stack does.

Test plan

  • make all from slurm/ builds and imports successfully.
  • CI version-gate passes against the new Dockerfile.
  • pretrain-llama2.sbatch smoke run on pre image — passes 1n
    (~710 TFLOPS/GPU); multi-node hangs (LOAD_FAIL above).
  • pretrain-llama2.sbatch smoke run on post image, after both
    compatibility fixes — passes (~809 TFLOPS/GPU 1n; ~732–782 4n–32n).
  • pretrain-llama3-8b.sbatch smoke run on post image (1n).

KeitaW added 5 commits April 29, 2026 13:43
NGC pytorch:26.02-py3 ships CUDA 13 and a recent NCCL with native sm_103
binaries, which Blackwell Ultra (B300) needs to avoid the PTX-JIT slow
path. The ARG bumps line up with the repo's CI version gate
(EFA >= 1.47.0, NCCL >= 2.28, CUDA >= 13.0): EFA installer 1.43.2 ->
1.48.0, plus explicit NCCL_VERSION / AWS_OFI_NCCL_VERSION ARGs so the
gate's `grep nccl` finds compliant values even though the base image and
EFA installer provide them. GDRCopy v2.5.1 -> v2.5.2.

Megatron-LM is bumped from core_v0.12.1 to core_v0.17.0 (Apr 2026
release, latest core_v* tag); transformers 4.52.4 -> 4.57.6 (latest 4.x
patch — staying on 4.x to avoid the API breaks in 5.x).
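
As a usage sketch, building the image with these pins overridden might look
like the following; NCCL_VERSION and AWS_OFI_NCCL_VERSION are the explicit
ARGs named above, while the EFA/GDRCopy ARG names and the Dockerfile name
are assumptions:

    # Values shown are the pinned defaults from the stack delta above.
    docker build \
      --build-arg EFA_INSTALLER_VERSION=1.48.0 \
      --build-arg NCCL_VERSION=v2.30.4-1 \
      --build-arg AWS_OFI_NCCL_VERSION=v1.19.0 \
      --build-arg GDRCOPY_VERSION=v2.5.2 \
      -t megatron-training:26.02 \
      -f megatron-lm.Dockerfile .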

Adds slurm/llama3/pretrain-llama3-8b.sbatch as a Llama 3 8B-specific
launcher. Defaults are tuned for 8x B300 per node: TP=1, PP=1, CP=2,
seq_len=8192, bf16, transformer_engine. Uses HuggingFaceTokenizer with
meta-llama/Meta-Llama-3-8B and rotary-base=500000 (the Llama 3 RoPE
base, distinct from Llama 2's 10000). Data-prep mirrors the existing
llama2 flow; pointer documented in slurm/llama3/README.md.

core_v0.17.0's data iterator builder dereferences args.eval_interval
even when --eval-iters 0, crashing with TypeError("unsupported operand
type(s) for //: 'int' and 'NoneType'") in
megatron/training/training.py:1143. Setting --eval-interval (any value;
1000 mirrors the gpt3 sbatch in this directory) is enough to satisfy
the new validator without changing behavior, since --eval-iters 0
disables eval anyway.

core_v0.17.0 lazy-builds the megatron.core.datasets.helpers_cpp C++ module
the first time a dataset is accessed. The build runs on rank 0 only, but
/workspace inside each Pyxis container is local — when training spans
multiple nodes, ranks on the other nodes never see the rank-0 build and
crash with:

    ModuleNotFoundError: No module named 'megatron.core.datasets.helpers_cpp'

Baking the .so into the image at build time makes it available to every
rank on every node from the start. Observed on post 4n/8n/16n runs against
the pre-fix image; passes on 1n because there's only one node.

Detect Python's include dir + pybind11 include + extension suffix from the
interpreter so this stays correct if the NGC base image's Python version
moves (currently 3.12 in nvcr.io/nvidia/pytorch:26.02-py3).
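
A sketch of the pre-build step a Dockerfile RUN might perform, assuming the
extension is compiled from megatron/core/datasets/helpers.cpp with pybind11;
paths and compiler flags here are illustrative, not copied from the Dockerfile:

    # Detect interpreter-specific paths so the step survives Python bumps.
    cd /workspace/Megatron-LM
    PY_INC=$(python3 -c 'import sysconfig; print(sysconfig.get_paths()["include"])')
    PYBIND_INC=$(python3 -c 'import pybind11; print(pybind11.get_include())')
    EXT_SUFFIX=$(python3 -c 'import sysconfig; print(sysconfig.get_config_var("EXT_SUFFIX"))')

    # Bake the shared object next to the package so every rank on every node can
    # import megatron.core.datasets.helpers_cpp without a lazy rank-0 build.
    g++ -O3 -Wall -shared -std=c++17 -fPIC -fopenmp \
        -I"$PY_INC" -I"$PYBIND_INC" \
        megatron/core/datasets/helpers.cpp \
        -o "megatron/core/datasets/helpers_cpp${EXT_SUFFIX}"
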
The llama3 sbatch was missing the same --eval-interval flag that crashed
the llama2 sbatch on core_v0.17.0:

    eval_iters = (args.train_iters // args.eval_interval + 1) * args.eval_iters
    TypeError: unsupported operand type(s) for //: 'int' and 'NoneType'

Adding the same harmless `--eval-interval 1000` mirrors the llama2 and
gpt3 sbatch defaults; --eval-iters 0 still disables eval at runtime.
@KeitaW KeitaW marked this pull request as ready for review May 8, 2026 10:25