@@ -81,16 +81,13 @@ def evo2_gpt_forward_step(model, batch) -> torch.Tensor:
8181
8282
8383class Evo2GPTModel (GPTModel ):
84- """Mamba model that extends GPTModel for integration with NeMo.
85-
86- Note that the loss calculation is handled by CustomMCoreMambaModel instead.
87- """
84+ """GPT model that extends GPTModel for integration with NeMo."""
8885
8986 @override
9087 def get_inference_wrapper (
9188 self , params_dtype , inference_batch_times_seqlen_threshold , inference_max_seq_length = 8192
9289 ) -> GPTInferenceWrapper :
93- """Gets the inference wrapper for the Mamba model."""
90+ """Gets the inference wrapper for the GPT model."""
9491 model = self
9592 while model is not None :
9693 if getattr (model , "module" , None ) is not None :
@@ -133,7 +130,7 @@ def forward(
133130 runtime_gather_output : bool | None = None ,
134131 loss_mask : torch .Tensor | None = None ,
135132 ) -> torch .Tensor :
136- """Forward pass that delegates to CustomMCoreMambaModel , which handles loss calculation."""
133+ """Forward pass that delegates to GPTModel , which handles loss calculation."""
137134 extra_kwargs = {"packed_seq_params" : packed_seq_params } if packed_seq_params is not None else {}
138135 output_tensor = self .module (
139136 input_ids ,
@@ -147,21 +144,17 @@ def forward(
147144 loss_mask = loss_mask , # Pass loss_mask to the Megatron module
148145 ** extra_kwargs ,
149146 )
150-
151- # Return whatever CustomMCoreMambaModel.forward returns
152- # (logits during inference, loss during training)
153147 return output_tensor
154148
155149
156- # Custom MCoreMambaModel with reweighted loss calculation
157150class Evo2StyleMCoreGPTModel (megatron .core .models .gpt .gpt_model .GPTModel ):
158- """Custom version of MCoreMambaModel that implements reweighted loss calculation.
151+ """Custom version of GPTModel that implements reweighted loss calculation.
159152
160153 Note that this is similar to the HyenaModel for uppercase/lowercase handling.
161154 """
162155
163156 def __init__ (self , * args , ** kwargs ):
164- """Initializes `Evo2StyleMCoreMambaModel ` with unique parameters for the Evo2 variant of `MCoreMambaModel `."""
157+ """Initializes `Evo2StyleMCoreGPTModel ` with unique parameters for the Evo2 variant of `GPTModel `."""
165158 super ().__init__ (* args , ** kwargs )
166159 if self .config .use_targeted_variance_loss :
167160 if not hasattr (self .config , "embedding_init_method_std" ):
@@ -205,9 +198,9 @@ def forward(self, *args, labels: torch.Tensor | None = None, loss_mask: torch.Te
205198
206199
207200def gpt_no_weight_decay_cond (name , param , exclude_embeddings : bool = False ):
208- """Condition for no weight decay for Mamba parameters.
201+ """Condition for no weight decay for GPT parameters.
209202
210- Note that this follows the same pattern as in the original Mamba implementation.
203+ Note that this follows the same pattern as in the original GPT implementation.
211204 """
212205 # Parameters that should not have weight decay (e.g. embeddings when excluded, or params flagged `_no_weight_decay`)
213206 if ("embedding" in name and exclude_embeddings ) or getattr (param , "_no_weight_decay" , False ):
@@ -222,16 +215,16 @@ def gpt_no_weight_decay_cond(name, param, exclude_embeddings: bool = False):
222215
223216
224217def gpt_no_weight_decay_cond_with_embeddings (name , param ):
225- """Condition for no weight decay for Mamba parameters with embeddings.
218+ """Condition for no weight decay for GPT parameters with embeddings.
226219
227- Note that this follows the same pattern as in the original Mamba implementation but also skips WD on embeddings.
220+ Note that this follows the same pattern as in the original GPT implementation but also skips WD on embeddings.
228221 """
229222 return gpt_no_weight_decay_cond (name , param , exclude_embeddings = True )
230223
231224
232225@dataclass
233226class LLama31ConfigEvoLoss3B (llm .Llama3Config8B ):
234- """Config for 8B hybrid Mamba model."""
227+ """Config for a Llama 3.1 GPT model built on `llm.Llama3Config8B`."""
235228
236229 # RoPE/context length related block:
237230 rotary_base : int = 500_000
0 commit comments