
Commit 5778321

Refactor the way that we do weight decay skipping for hyena to follow ToT mbridge. (#1429)
### Description

* No change to the user-level API, but under the hood this makes use of the new mbridge API for defining custom weight decay skips. Depends on NVIDIA-NeMo/Megatron-Bridge#2010
* Update the tokenizer to support the new mbridge API for tokenizer init, which no longer takes a path object for path-based inputs. Path objects no longer work with megatron through this code path, so the recipe switches to strings.
* Remove unused nemo2 code/files that were left over in the refactor.

### Type of changes

<!-- Mark the relevant option with an [x] -->

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [x] Refactor
- [ ] Documentation update
- [ ] Other (please describe):

### CI Pipeline Configuration

Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.

- [ciflow:skip](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:skip) - Skip all CI tests for this PR
- [ciflow:notebooks](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:notebooks) - Run Jupyter notebooks execution tests for bionemo2
- [ciflow:slow](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:slow) - Run slow single GPU integration tests marked as @pytest.mark.slow for bionemo2
- [ciflow:all](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all) - Run all tests (unit tests, slow tests, and notebooks) for bionemo2. This label can be used to enforce running tests for all bionemo2.
- [ciflow:all-recipes](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all-recipes) - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes.

Unit tests marked as `@pytest.mark.multi_gpu` or `@pytest.mark.distributed` are not run in the PR pipeline.

For more details, see [CONTRIBUTING](CONTRIBUTING.md)

> [!NOTE]
> By default, only basic unit tests are run. Add appropriate labels to enable additional test coverage.

#### Authorizing CI Runs

We use [copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation) to manage authorization of CI runs on NVIDIA's compute resources.

- If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123)
- If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an `/ok to test` comment on the pull request to trigger CI. This will need to be done for each new commit.

#### Triggering Code Rabbit AI Review

To trigger a code review from code rabbit, comment on a pull request with one of these commands:

- @coderabbitai review - Triggers a standard review
- @coderabbitai full review - Triggers a comprehensive review

See https://docs.coderabbit.ai/reference/review-commands for a full list of commands.

### Pre-submit Checklist

<!--- Ensure all items are completed before submitting -->

- [ ] I have tested these changes locally
- [ ] I have updated the documentation accordingly
- [ ] I have added/updated tests as needed
- [ ] All existing tests pass successfully

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

## Release Notes

* **New Features**
  * Added `no_weight_decay_embeddings` configuration parameter for Evo2 training recipes to control embedding weight decay behavior.
* **Chores**
  * Updated Megatron-related dependency versions.
* **Tests**
  * Improved test fixture scoping for better test isolation.

<sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: John St. John <jstjohn@nvidia.com>
1 parent cca7577 commit 5778321

17 files changed

Lines changed: 160 additions & 1224 deletions


bionemo-recipes/recipes/evo2_megatron/pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -95,8 +95,8 @@ bionemo-core = { git = "https://github.com/NVIDIA/bionemo-framework.git", branch
 nvidia-resiliency-ext = { git = "https://github.com/NVIDIA/nvidia-resiliency-ext.git", rev = "54f85fe422d296cf04ea524130014bd3a2c3add1" } # pragma: allowlist secret
 
 # Megatron Bundle. This points to a version that still supports the deprecated no_weight_decay_cond field until the API for an alternative has been finalized.
-megatron-bridge = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "18ef1b61309dd45bc0535fb7c60064b9d8829a35" } # pragma: allowlist secret
-megatron-core = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "18ef1b61309dd45bc0535fb7c60064b9d8829a35", subdirectory = "3rdparty/Megatron-LM" } # pragma: allowlist secret
+megatron-bridge = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "549e3cb970c170b1d7a86d021261efe05e8a5d9f" } # pragma: allowlist secret
+megatron-core = { git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge.git", rev = "549e3cb970c170b1d7a86d021261efe05e8a5d9f", subdirectory = "3rdparty/Megatron-LM" } # pragma: allowlist secret
 
 [tool.uv.extra-build-dependencies]
 warp-lang = ["wheel_stub"]

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/data/dataset_tokenizer.py

Lines changed: 13 additions & 8 deletions
@@ -27,8 +27,8 @@
 
 
 REPO_BASE_DIR = Path(__file__).parent.parent.parent.parent.parent
-DEFAULT_HF_TOKENIZER_MODEL_PATH = REPO_BASE_DIR / "tokenizers" / "nucleotide_fast_tokenizer_256"
-DEFAULT_HF_TOKENIZER_MODEL_PATH_512 = REPO_BASE_DIR / "tokenizers" / "nucleotide_fast_tokenizer_512"
+DEFAULT_HF_TOKENIZER_MODEL_PATH = str(REPO_BASE_DIR / "tokenizers" / "nucleotide_fast_tokenizer_256")
+DEFAULT_HF_TOKENIZER_MODEL_PATH_512 = str(REPO_BASE_DIR / "tokenizers" / "nucleotide_fast_tokenizer_512")
 
 
 class Evo2DatasetTokenizer:
@@ -39,18 +39,18 @@ def __init__(self, params: Evo2PreprocessingConfig | None = None):
         # Pass all NeMo2/Megatron-compliant parameters associated with config.Evo2PreprocessingConfig.
         self.params: Evo2PreprocessingConfig = params if params is not None else Evo2PreprocessingConfig()
         if self.params.hf_tokenizer_model_path is not None:
-            hf_tokenizer_model_or_path = Path(self.params.hf_tokenizer_model_path)
-            hf_tokenizer_desc: str = hf_tokenizer_model_or_path.name
-            assert hf_tokenizer_model_or_path.exists(), (
+            hf_tokenizer_model_or_path = str(self.params.hf_tokenizer_model_path)
+            hf_tokenizer_desc: str = Path(hf_tokenizer_model_or_path).name
+            assert Path(hf_tokenizer_model_or_path).exists(), (
                 f"Hugging Face tokenizer model path {hf_tokenizer_model_or_path} does not exist."
             )
         elif self.params.hf_tokenizer_model_name is not None:
             hf_tokenizer_model_or_path = str(self.params.hf_tokenizer_model_name)
             hf_tokenizer_desc = hf_tokenizer_model_or_path.replace("/", "--").replace(":", "--")
         else:
             hf_tokenizer_model_or_path = DEFAULT_HF_TOKENIZER_MODEL_PATH
-            hf_tokenizer_desc = hf_tokenizer_model_or_path.name
-            assert hf_tokenizer_model_or_path.exists(), (
+            hf_tokenizer_desc = Path(hf_tokenizer_model_or_path).name
+            assert Path(hf_tokenizer_model_or_path).exists(), (
                 f"Default Hugging Face tokenizer model path {hf_tokenizer_model_or_path} does not exist."
             )
         self.hf_tokenizer_desc = hf_tokenizer_desc
@@ -81,7 +81,12 @@ def tokenize(
             else:
                 t_fixed = t
             # Tokenize the string.
-            text_ids: list = self.tokenizer.text_to_ids(t_fixed)
+            if hasattr(self.tokenizer, "text_to_ids"):
+                # Handle the legacy NeMo2 style tokenizer.
+                text_ids: list = self.tokenizer.text_to_ids(t_fixed)
+            else:
+                # Handle the new Megatron-Bridge style tokenizer.
+                text_ids: list = self.tokenizer.tokenize(t_fixed)
             if drop_empty_sequences and len(text_ids) == 0:
                 continue
             # Append EOD token (EOD ID: 0) if appropriate.
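
The string-based path handling in `Evo2DatasetTokenizer.__init__` above repeats the same three steps in two branches: convert to `str`, derive a description from the final path component, and check that the path exists. A minimal sketch of how that could be factored out follows. The helper name `resolve_tokenizer_path` is hypothetical and not part of this commit, and it raises `FileNotFoundError` where the diff uses `assert`.

```python
from __future__ import annotations

from pathlib import Path


def resolve_tokenizer_path(path_like: str | Path) -> tuple[str, str]:
    """Return (path_as_str, description) for a local tokenizer directory.

    Hypothetical helper, not part of the commit. It mirrors the pattern in the
    diff: the path is kept as a plain string for downstream consumers and is
    wrapped in Path only for the existence check and the short description.
    """
    path_str = str(path_like)
    path = Path(path_str)
    if not path.exists():
        raise FileNotFoundError(f"Hugging Face tokenizer model path {path_str} does not exist.")
    return path_str, path.name
```

Raising an exception instead of asserting keeps the check active when Python runs with `-O`, which strips `assert` statements.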

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/data/fasta_dataset.py

Lines changed: 6 additions & 1 deletion
@@ -56,7 +56,12 @@ def __len__(self):
     def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
         """Get an item from the dataset."""
         sequence = self.fasta[self.seqids[idx]].sequence().upper()
-        tokenized_seq = self.tokenizer.text_to_ids(sequence)
+        if hasattr(self.tokenizer, "tokenize"):
+            # Handle the new Megatron-Bridge style tokenizer.
+            tokenized_seq = self.tokenizer.tokenize(sequence)
+        else:
+            # Handle the legacy NeMo2 style tokenizer.
+            tokenized_seq = self.tokenizer.text_to_ids(sequence)
         if self.prepend_bos: # in pretraining we use EOS to start new sequences.
             tokens: list[int] = [self.tokenizer.eod, *tokenized_seq]
         else:

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/data/sharded_eden_dataset_provider.py

Lines changed: 23 additions & 4 deletions
@@ -332,7 +332,13 @@ def _prepare_control_tags(self):
         for seq_id in unique_sequence_ids:
             # Extract meaningful part from sequence ID for control tag
             ctrl_name = seq_id.split("__")[0] if "__" in seq_id else seq_id
-            self.ctrl_ids_map[seq_id] = self.tokenizer.text_to_ids(f"<ctrl_{ctrl_name.lower()}>")
+            if hasattr(self.tokenizer, "tokenize"):
+                # Handle the new Megatron-Bridge style tokenizer.
+                ctrl_ids = self.tokenizer.tokenize(f"<ctrl_{ctrl_name.lower()}>")
+            else:
+                # Handle the legacy NeMo2 style tokenizer.
+                ctrl_ids = self.tokenizer.text_to_ids(f"<ctrl_{ctrl_name.lower()}>")
+            self.ctrl_ids_map[seq_id] = ctrl_ids
 
     def __len__(self) -> int:
         """Return the length of the dataset."""
@@ -455,7 +461,12 @@ def __getitem__(self, idx: np.int64) -> Dict[str, torch.Tensor]:
             seq = self.reverse_complement(seq)
 
         # Tokenize
-        token_ids = header + self.tokenizer.text_to_ids(seq) + footer
+        if hasattr(self.tokenizer, "tokenize"):
+            # Handle the new Megatron-Bridge style tokenizer.
+            token_ids = header + self.tokenizer.tokenize(seq) + footer
+        else:
+            # Handle the legacy NeMo2 style tokenizer.
+            token_ids = header + self.tokenizer.text_to_ids(seq) + footer
 
         # Pad/trim
         if len(token_ids) < self.seq_length:
@@ -516,7 +527,10 @@ def sep_id(self) -> int:
         """Get the separator token ID."""
         sep_id = getattr(self.tokenizer, "_sep_id", None)
         if sep_id is None:
-            sep_id = self.tokenizer.text_to_ids("<SEP>")
+            if hasattr(self.tokenizer, "tokenize"):
+                sep_id = self.tokenizer.tokenize("<SEP>")
+            else:
+                sep_id = self.tokenizer.text_to_ids("<SEP>")
             if len(sep_id) == 1:
                 sep_id = sep_id[0]
             else:
@@ -530,7 +544,12 @@ def pad_id(self) -> int:
         """Get the padding token ID."""
         pad_id = getattr(self.tokenizer, "pad_id", None)
         if pad_id is None:
-            pad_id = self.tokenizer.text_to_ids("<PAD>")
+            if hasattr(self.tokenizer, "tokenize"):
+                # Handle the new Megatron-Bridge style tokenizer.
+                pad_id = self.tokenizer.tokenize("<PAD>")
+            else:
+                # Handle the legacy NeMo2 style tokenizer.
+                pad_id = self.tokenizer.text_to_ids("<PAD>")
             if len(pad_id) == 1:
                 pad_id = pad_id[0]
             else:
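
The `hasattr` dispatch between the two tokenizer styles appears at several call sites in this commit. `dataset_tokenizer.py` checks for `text_to_ids` first, while `fasta_dataset.py` and `sharded_eden_dataset_provider.py` check for `tokenize` first, so a tokenizer object that exposes both methods would be routed differently depending on the file. One way to make the order explicit is a single helper, sketched below. The name `encode_text` is hypothetical and not part of this commit, and the choice to prefer `tokenize` is an assumption that follows the majority of the call sites.

```python
from typing import Any


def encode_text(tokenizer: Any, text: str) -> list[int]:
    """Encode text with either tokenizer style.

    Hypothetical helper, not part of the commit. It prefers the new
    Megatron-Bridge style ``tokenize`` method and falls back to the legacy
    NeMo2 style ``text_to_ids`` method.
    """
    if hasattr(tokenizer, "tokenize"):
        # New Megatron-Bridge style tokenizer.
        return list(tokenizer.tokenize(text))
    # Legacy NeMo2 style tokenizer.
    return list(tokenizer.text_to_ids(text))
```

With such a helper, a call site like the control tag lookup would reduce to `encode_text(self.tokenizer, f"<ctrl_{ctrl_name.lower()}>")`.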

bionemo-recipes/recipes/evo2_megatron/src/bionemo/evo2/models/evo2_provider.py

Lines changed: 51 additions & 108 deletions
@@ -25,7 +25,11 @@
 import torch
 from megatron.bridge.models.model_provider import ModelProviderMixin
 from megatron.bridge.models.transformer_config import TransformerConfig
-from megatron.bridge.training.config import ConfigContainer
+from megatron.bridge.training.config import (
+    ConfigContainer,
+    OptimizerConfigOverrideProvider,
+    OptimizerConfigOverrideProviderContext,
+)
 from megatron.bridge.training.gpt_step import get_batch_from_iterator
 from megatron.bridge.training.losses import masked_next_token_loss
 from megatron.bridge.training.state import GlobalState
@@ -34,24 +38,21 @@
 from megatron.bridge.utils.vocab_utils import calculate_padded_vocab_size
 from megatron.core import parallel_state
 from megatron.core.inference.contexts import StaticInferenceContext
+from megatron.core.optimizer import (
+    ParamGroupOverride,
+    ParamKey,
+    ParamPredicate,
+)
 from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage
 from megatron.core.transformer.enums import AttnBackend
 from megatron.core.utils import get_batch_on_this_cp_rank, get_model_config
 
 from bionemo.evo2.models.megatron.hyena.hyena_config import HyenaConfig as _HyenaConfigForFlops
-
-# from nemo.collections.llm.gpt.model.base import GPTModel, gpt_data_step # FIXME do megatron bridge thing instead of this
 from bionemo.evo2.models.megatron.hyena.hyena_layer_specs import get_hyena_stack_spec
 from bionemo.evo2.models.megatron.hyena.hyena_model import HyenaModel as MCoreHyenaModel
 from bionemo.evo2.models.megatron.hyena.hyena_utils import hyena_no_weight_decay_cond
 
 
-# from nemo.lightning import get_vocab_size, io, teardown
-# from nemo.lightning.base import NEMO_MODELS_CACHE
-# from nemo.lightning.io.state import TransformFns
-# from nemo.utils import logging
-
-
 def get_vocab_size(*args, **kwargs):
     raise NotImplementedError("FIXME get_vocab_size is not implemented Find it in megatron bridge")
 
@@ -60,7 +61,47 @@ def gpt_data_step(*args, **kwargs):
     raise NotImplementedError("FIXME gpt_data_step is not implemented Find it in megatron bridge")
 
 
-# FIXME convert the nemo style configs to megatron bridge style configs
+@dataclass
+class HyenaOptimizerConfigOverrideProvider(OptimizerConfigOverrideProvider):
+    """Hyena-specific optimizer config override provider."""
+
+    no_weight_decay_embeddings: bool = False
+
+    def build_config_overrides(
+        self, context: OptimizerConfigOverrideProviderContext
+    ) -> dict[ParamKey, ParamGroupOverride] | None:
+        """Build config overrides for weight decay based on scheduler configuration.
+
+        This function creates parameter-specific overrides for weight decay behavior.
+        By default, weight decay is skipped for bias parameters and 1D parameters.
+        For Qwen3-Next models, weight decay is applied to q_layernorm and k_layernorm.
+        """
+        optimizer_config = context.optimizer_config
+        config_overrides: dict[ParamKey, ParamGroupOverride] = {}
+        param_length_1_match = ParamPredicate(name="param_len_1", fn=lambda param: len(param.shape) == 1)
+        name_tuple: tuple[str, ...] = (
+            "*.bias",
+            "*.filter.p",
+            "*.filter.R",
+            "*.filter.gamma",
+            "*.short_conv.short_conv_weight",
+        )
+        if self.no_weight_decay_embeddings:
+            name_tuple += ("*embedding*",)
+        param_wd_mult_key = ParamKey(
+            name=name_tuple, # type: ignore
+            predicate=param_length_1_match,
+        )
+
+        config_overrides[param_wd_mult_key] = ParamGroupOverride(wd_mult=0.0) # type: ignore
+
+        if optimizer_config.decoupled_lr is not None:
+            decoupled_lr_config: ParamGroupOverride = {"max_lr": optimizer_config.decoupled_lr}
+            decoupled_param_key = ParamKey(attr="is_embedding_or_output_parameter")
+            if optimizer_config.decoupled_min_lr is not None:
+                decoupled_lr_config["min_lr"] = optimizer_config.decoupled_min_lr
+            config_overrides[decoupled_param_key] = decoupled_lr_config
+        return config_overrides
 
 
 class HyenaInferenceContext(StaticInferenceContext):
@@ -75,103 +116,6 @@ def reset(self):
                 delattr(self, key)
 
 
-# FIXME convert this to the megatron bridge style config for inference.
-# class HyenaModel(GPTModel):
-#     """This is a wrapper around the MCoreHyenaModel to allow for inference.
-
-#     Our model follows the same API as the GPTModel, but the megatron model class is different so we need to handle the inference wrapper slightly differently.
-#     """
-
-#     def get_inference_wrapper(
-#         self, params_dtype, inference_batch_times_seqlen_threshold, inference_max_seq_length=None
-#     ) -> torch.Tensor:
-#         """Gets the inference wrapper for the Hyena model.
-
-#         Args:
-#             params_dtype: The data type for model parameters
-#             inference_batch_times_seqlen_threshold: Threshold for batch size * sequence length during inference
-#             inference_max_seq_length: Maximum sequence length for inference
-
-#         Returns:
-#             GPTInferenceWrapper: The inference wrapper for the model
-
-#         Raises:
-#             ValueError: If MCoreHyenaModel instance not found or vocab size cannot be determined
-#         """
-#         # This is to get the MCore model required in GPTInferenceWrapper.
-#         mcore_model = self.module
-#         while mcore_model:
-#             if type(mcore_model) is MCoreHyenaModel:
-#                 break
-#             mcore_model = getattr(mcore_model, "module", None)
-#         if mcore_model is None or type(mcore_model) is not MCoreHyenaModel:
-#             raise ValueError("Exact MCoreHyenaModel instance not found in the model structure.")
-
-#         vocab_size = None
-#         if self.tokenizer is not None:
-#             vocab_size = self.tokenizer.vocab_size
-#         elif hasattr(self.config, "vocab_size"):
-#             vocab_size = self.config.vocab_size
-#         else:
-#             raise ValueError(
-#                 "Unable to find vocab size."
-#                 " Either pass in a tokenizer with vocab size, or set vocab size in the model config"
-#             )
-
-#         inference_wrapper_config = InferenceWrapperConfig(
-#             hidden_size=mcore_model.config.hidden_size,
-#             params_dtype=params_dtype,
-#             inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
-#             padded_vocab_size=vocab_size,
-#             inference_max_seq_length=inference_max_seq_length,
-#             inference_max_requests=1,
-#         )
-
-#         inference_context = HyenaInferenceContext.from_config(inference_wrapper_config)
-#         model_inference_wrapper = GPTInferenceWrapper(mcore_model, inference_wrapper_config, inference_context)
-#         return model_inference_wrapper
-
-#     def forward(
-#         self,
-#         input_ids: torch.Tensor,
-#         position_ids: torch.Tensor,
-#         attention_mask: Optional[torch.Tensor] = None,
-#         labels: Optional[torch.Tensor] = None,
-#         decoder_input: Optional[torch.Tensor] = None,
-#         loss_mask: Optional[torch.Tensor] = None,
-#         inference_context=None,
-#         packed_seq_params=None,
-#     ) -> torch.Tensor:
-#         """Forward pass of the Hyena model.
-
-#         Args:
-#             input_ids: Input token IDs
-#             position_ids: Position IDs for input tokens
-#             attention_mask: Optional attention mask
-#             labels: Optional labels for loss computation
-#             decoder_input: Optional decoder input
-#             loss_mask: Optional loss mask
-#             inference_context: Optional inference parameters
-#             packed_seq_params: Optional parameters for packed sequences
-
-
-#         Returns:
-#             torch.Tensor: Output tensor from the model
-#         """
-#         extra_kwargs = {"packed_seq_params": packed_seq_params} if packed_seq_params is not None else {}
-#         output_tensor = self.module(
-#             input_ids,
-#             position_ids,
-#             attention_mask,
-#             decoder_input=decoder_input,
-#             labels=labels,
-#             inference_context=inference_context,
-#             loss_mask=loss_mask,
-#             **extra_kwargs,
-#         )
-#         return output_tensor
-
-
 def get_batch(
     data_iterator: Iterable, cfg: ConfigContainer, use_mtp: bool = False, *, pg_collection
 ) -> tuple[
@@ -329,7 +273,6 @@ def _create_loss_function(loss_mask: torch.Tensor, check_for_nan_in_loss: bool,
     )
 
 
-# FIXME make sure these conform to megatron/megatron bridge style.
 @dataclass
 class HyenaModelProvider(TransformerConfig, ModelProviderMixin[MCoreHyenaModel]):
     """Configuration dataclass for Hyena.
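
The weight decay skip rule that `HyenaOptimizerConfigOverrideProvider.build_config_overrides` registers in this file can be illustrated without Megatron installed. The sketch below is a stand-alone approximation, not the Megatron-Core implementation. It assumes that a `ParamKey` matches a parameter when either one of its name glob patterns or its predicate matches; confirm that against the pinned Megatron-Core revision. The function name `wd_mult_for` is hypothetical, and parameter shapes are passed as plain tuples instead of tensors.

```python
from fnmatch import fnmatchcase

# Name patterns copied from the diff; parameters matching any of them skip weight decay.
NO_DECAY_NAME_PATTERNS: tuple[str, ...] = (
    "*.bias",
    "*.filter.p",
    "*.filter.R",
    "*.filter.gamma",
    "*.short_conv.short_conv_weight",
)


def wd_mult_for(
    param_name: str,
    shape: tuple[int, ...],
    no_weight_decay_embeddings: bool = False,
) -> float:
    """Return the weight decay multiplier for one parameter.

    Hypothetical stand-in for the ParamKey / ParamGroupOverride(wd_mult=0.0)
    pair in the diff. ASSUMPTION: a name match OR a 1-D shape is enough to
    skip weight decay.
    """
    patterns = NO_DECAY_NAME_PATTERNS
    if no_weight_decay_embeddings:
        patterns = patterns + ("*embedding*",)
    # Case-sensitive glob match, so "*.filter.R" does not match "...filter.r".
    name_matches = any(fnmatchcase(param_name, pattern) for pattern in patterns)
    is_one_dimensional = len(shape) == 1
    return 0.0 if (name_matches or is_one_dimensional) else 1.0
```

For example, `wd_mult_for("embedding.word_embeddings.weight", (512, 64))` keeps weight decay by default and skips it only when `no_weight_decay_embeddings=True`, which is the behavior the new recipe parameter is described as controlling.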
