Skip to content

Commit 47f3022

Browse files
committed
fix docs
Signed-off-by: Yang Zhang <yangzhang@nvidia.com>
1 parent 52dacd6 commit 47f3022

1 file changed

Lines changed: 10 additions & 17 deletions

File tree

  • sub-packages/bionemo-evo2/src/bionemo/evo2/models

sub-packages/bionemo-evo2/src/bionemo/evo2/models/gpt.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,13 @@ def evo2_gpt_forward_step(model, batch) -> torch.Tensor:
8181

8282

8383
class Evo2GPTModel(GPTModel):
84-
"""Mamba model that extends GPTModel for integration with NeMo.
85-
86-
Note that the loss calculation is handled by CustomMCoreMambaModel instead.
87-
"""
84+
"""GPT model that extends GPTModel for integration with NeMo."""
8885

8986
@override
9087
def get_inference_wrapper(
9188
self, params_dtype, inference_batch_times_seqlen_threshold, inference_max_seq_length=8192
9289
) -> GPTInferenceWrapper:
93-
"""Gets the inference wrapper for the Mamba model."""
90+
"""Gets the inference wrapper for the GPT model."""
9491
model = self
9592
while model is not None:
9693
if getattr(model, "module", None) is not None:
@@ -133,7 +130,7 @@ def forward(
133130
runtime_gather_output: bool | None = None,
134131
loss_mask: torch.Tensor | None = None,
135132
) -> torch.Tensor:
136-
"""Forward pass that delegates to CustomMCoreMambaModel, which handles loss calculation."""
133+
"""Forward pass that delegates to GPTModel, which handles loss calculation."""
137134
extra_kwargs = {"packed_seq_params": packed_seq_params} if packed_seq_params is not None else {}
138135
output_tensor = self.module(
139136
input_ids,
@@ -147,21 +144,17 @@ def forward(
147144
loss_mask=loss_mask, # Pass loss_mask to the Megatron module
148145
**extra_kwargs,
149146
)
150-
151-
# Return whatever CustomMCoreMambaModel.forward returns
152-
# (logits during inference, loss during training)
153147
return output_tensor
154148

155149

156-
# Custom MCoreMambaModel with reweighted loss calculation
157150
class Evo2StyleMCoreGPTModel(megatron.core.models.gpt.gpt_model.GPTModel):
158-
"""Custom version of MCoreMambaModel that implements reweighted loss calculation.
151+
"""Custom version of GPTModel that implements reweighted loss calculation.
159152
160153
Note that this is similar to the HyenaModel for uppercase/lowercase handling.
161154
"""
162155

163156
def __init__(self, *args, **kwargs):
164-
"""Initializes `Evo2StyleMCoreMambaModel` with unique parameters for the Evo2 variant of `MCoreMambaModel`."""
157+
"""Initializes `Evo2StyleMCoreGPTModel` with unique parameters for the Evo2 variant of `GPTModel`."""
165158
super().__init__(*args, **kwargs)
166159
if self.config.use_targeted_variance_loss:
167160
if not hasattr(self.config, "embedding_init_method_std"):
@@ -205,9 +198,9 @@ def forward(self, *args, labels: torch.Tensor | None = None, loss_mask: torch.Te
205198

206199

207200
def gpt_no_weight_decay_cond(name, param, exclude_embeddings: bool = False):
208-
"""Condition for no weight decay for Mamba parameters.
201+
"""Condition for no weight decay for GPT parameters.
209202
210-
Note that this follows the same pattern as in the original Mamba implementation.
203+
Note that this follows the same pattern as in the original GPT implementation.
211204
"""
212205
# Mamba-specific parameters that should not have weight decay
213206
if ("embedding" in name and exclude_embeddings) or getattr(param, "_no_weight_decay", False):
@@ -222,16 +215,16 @@ def gpt_no_weight_decay_cond(name, param, exclude_embeddings: bool = False):
222215

223216

224217
def gpt_no_weight_decay_cond_with_embeddings(name, param):
225-
"""Condition for no weight decay for Mamba parameters with embeddings.
218+
"""Condition for no weight decay for GPT parameters with embeddings.
226219
227-
Note that this follows the same pattern as in the original Mamba implementation but also skips WD on embeddings.
220+
Note that this follows the same pattern as in the original GPT implementation but also skips WD on embeddings.
228221
"""
229222
return gpt_no_weight_decay_cond(name, param, exclude_embeddings=True)
230223

231224

232225
@dataclass
233226
class LLama31ConfigEvoLoss3B(llm.Llama3Config8B):
234-
"""Config for 8B hybrid Mamba model."""
227+
"""Config for 8B hybrid GPT model."""
235228

236229
# RoPE/context length related block:
237230
rotary_base: int = 500_000

0 commit comments

Comments
 (0)