Skip to content

Commit 74ac48f

Browse files
committed
Attempt to address failing CI
Signed-off-by: John St. John <jstjohn@nvidia.com>
1 parent 3b8ebc2 commit 74ac48f

3 files changed

Lines changed: 85 additions & 7 deletions

File tree

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818

1919

2020
import math
21+
import sys
2122
from dataclasses import dataclass
2223
from functools import partial
24+
from pathlib import Path
2325
from typing import Callable, Iterable, Literal, Optional, Type
2426

2527
import torch
@@ -54,6 +56,30 @@
5456
from bionemo.evo2.models.megatron.hyena.hyena_utils import hyena_no_weight_decay_cond
5557

5658

59+
def _patch_megatron_dataset_helper_compile() -> None:
60+
"""Skip Megatron's runtime helper build when a wheel already ships the extension."""
61+
from megatron.core.datasets import utils as dataset_utils
62+
63+
original_compile_helpers = dataset_utils.compile_helpers
64+
if getattr(original_compile_helpers, "_evo2_prebuilt_helper_guard", False):
65+
guarded_compile_helpers = original_compile_helpers
66+
else:
67+
68+
def guarded_compile_helpers() -> None:
69+
datasets_dir = Path(dataset_utils.__file__).resolve().parent
70+
if not (datasets_dir / "Makefile").exists() and list(datasets_dir.glob("helpers_cpp*.so")):
71+
return None
72+
return original_compile_helpers()
73+
74+
guarded_compile_helpers._evo2_prebuilt_helper_guard = True
75+
dataset_utils.compile_helpers = guarded_compile_helpers
76+
77+
bridge_initialize = sys.modules.get("megatron.bridge.training.initialize")
78+
if bridge_initialize is not None:
79+
bridge_initialize.compile_helpers = guarded_compile_helpers
80+
81+
82+
_patch_megatron_dataset_helper_compile()
5783
register_allowed_target_prefix("bionemo.evo2.")
5884

5985

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/run/train.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -408,11 +408,20 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
408408
help="Use the faster, but maybe less accurate fused form of cross entropy, "
409409
"which also has bf16 grads internally.",
410410
) # DONE
411-
parser.add_argument(
412-
"--no-fp32-residual-connection",
411+
fp32_residual_group = parser.add_mutually_exclusive_group(required=False)
412+
fp32_residual_group.add_argument(
413+
"--fp32-residual-connection",
414+
dest="fp32_residual_connection",
413415
action="store_true",
414-
default=False,
415-
help="If set, turn off fp32 residual connections which may be faster but may impact accuracy.",
416+
default=None,
417+
help="Enable fp32 residual connections. Defaults to the selected model provider setting.",
418+
)
419+
fp32_residual_group.add_argument(
420+
"--no-fp32-residual-connection",
421+
dest="fp32_residual_connection",
422+
action="store_false",
423+
default=None,
424+
help="Disable fp32 residual connections. Defaults to the selected model provider setting.",
416425
) # DONE
417426
parser.add_argument(
418427
"--debug-ddp-parity-freq",
@@ -859,11 +868,11 @@ def train(args: argparse.Namespace) -> None:
859868
cfg.model.seq_len_interpolation_factor = args.seq_len_interpolation_factor
860869
cfg.model.calculate_per_token_loss = not args.no_calculate_per_token_loss
861870
model_type = infer_model_type(args.model_size)
862-
if model_type != "hyena" and not args.no_fp32_residual_connection:
871+
if args.fp32_residual_connection is not None:
872+
cfg.model.fp32_residual_connection = args.fp32_residual_connection
873+
if model_type != "hyena" and cfg.model.fp32_residual_connection:
863874
logger.info("Disabling fp32_residual_connection for non-Hyena model (not compatible with TE layers)")
864875
cfg.model.fp32_residual_connection = False
865-
else:
866-
cfg.model.fp32_residual_connection = not args.no_fp32_residual_connection
867876
cfg.model.cross_entropy_loss_fusion = args.cross_entropy_loss_fusion
868877
# cfg.model.cuda_graph_impl = "local" # or "transformer_engine"
869878
# cfg.model.cuda_graph_scope = "full_iteration"

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_model_providers.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515

1616
"""Tests for model provider instantiation, naming, and checkpoint converters."""
1717

18+
from pathlib import Path
19+
1820
import pytest
1921
import torch
2022

2123
from bionemo.evo2.models.evo2_provider import (
2224
HYENA_MODEL_OPTIONS,
2325
MODEL_OPTIONS,
2426
Hyena1bModelProvider,
27+
_patch_megatron_dataset_helper_compile,
2528
infer_model_type,
2629
)
2730
from bionemo.evo2.utils.checkpoint.mbridge_to_vortex import _split_fc1, mbridge_to_vortex_state_dict
@@ -63,6 +66,46 @@ def test_infer_model_type_unknown():
6366
infer_model_type("nonexistent_model")
6467

6568

69+
@pytest.mark.parametrize(
70+
("has_makefile", "has_prebuilt_extension", "expected_original_calls"),
71+
[
72+
(False, True, 0),
73+
(True, True, 1),
74+
(False, False, 1),
75+
],
76+
)
77+
def test_megatron_dataset_helper_compile_guard(
78+
monkeypatch: pytest.MonkeyPatch,
79+
tmp_path: Path,
80+
has_makefile: bool,
81+
has_prebuilt_extension: bool,
82+
expected_original_calls: int,
83+
):
84+
"""Skip Megatron's runtime make step only when a prebuilt helper extension exists."""
85+
from megatron.bridge.training import initialize as bridge_initialize
86+
from megatron.core.datasets import utils as dataset_utils
87+
88+
calls = []
89+
90+
def original_compile_helpers():
91+
calls.append("called")
92+
93+
if has_makefile:
94+
(tmp_path / "Makefile").write_text("all:\n")
95+
if has_prebuilt_extension:
96+
(tmp_path / "helpers_cpp.cpython-312-x86_64-linux-gnu.so").touch()
97+
98+
monkeypatch.setattr(dataset_utils, "__file__", str(tmp_path / "utils.py"))
99+
monkeypatch.setattr(dataset_utils, "compile_helpers", original_compile_helpers)
100+
monkeypatch.setattr(bridge_initialize, "compile_helpers", original_compile_helpers)
101+
102+
_patch_megatron_dataset_helper_compile()
103+
104+
dataset_utils.compile_helpers()
105+
assert bridge_initialize.compile_helpers is dataset_utils.compile_helpers
106+
assert len(calls) == expected_original_calls
107+
108+
66109
def _make_mock_savanna_sd(pattern: str) -> dict[str, torch.Tensor]:
67110
"""Create a minimal mock savanna state dict for the given pattern.
68111

0 commit comments

Comments
 (0)