Skip to content

Commit 841559e

Browse files
yeyu-nvidiakevalmorabia97
authored andcommitted
remove duplicated RMSNorm and use LlamaRMSNorm from transformers (#774)
## What does this PR do? Code cleanup **Overview:** Remove RMSNorm which is identical to LlamaRMSNorm from transformers. ## Usage <!-- You can potentially add a usage example below. --> ```python # Add a code snippet demonstrating how to use this ``` ## Testing <!-- Mention how have you tested your change if applicable. --> ## Before your PR is "*Ready for review*" <!-- If you haven't finished some of the above items you can still open `Draft` PR. --> - **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed. - **Is this change backward compatible?**: Yes/No <!--- If No, explain why. --> - **Did you write any new necessary tests?**: Yes/No - **Did you add or update any necessary documentation?**: Yes/No - **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No <!--- Only for new features, API changes, critical bug fixes or bw breaking changes. --> ## Additional Information <!-- E.g. related issue. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Updated the normalization implementation in the Eagle speculative module. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Ye Yu <yeyu@nvidia.com>
1 parent 58a5f1e commit 841559e

2 files changed

Lines changed: 2 additions & 21 deletions

File tree

modelopt/torch/speculative/eagle/utils.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
"""Eagle model utils."""
3737

3838
import torch
39-
from torch import nn
4039

4140

4241
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
@@ -71,21 +70,3 @@ def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = No
7170
inverted_mask = 1.0 - expanded_mask
7271

7372
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
74-
75-
76-
class RMSNorm(nn.Module):
77-
"""Borrowed from LlamaRMSNorm class."""
78-
79-
def __init__(self, hidden_size, eps=1e-6):
80-
"""LlamaRMSNorm is equivalent to T5LayerNorm."""
81-
super().__init__()
82-
self.weight = nn.Parameter(torch.ones(hidden_size))
83-
self.variance_epsilon = eps
84-
85-
def forward(self, hidden_states):
86-
"""Forward function for RMSNorm."""
87-
input_dtype = hidden_states.dtype
88-
hidden_states = hidden_states.to(torch.float32)
89-
variance = hidden_states.pow(2).mean(-1, keepdim=True)
90-
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
91-
return self.weight * hidden_states.to(input_dtype)

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949

5050
from ..eagle.conversion import EagleDMRegistry
5151
from ..eagle.eagle_model import EagleModel
52-
from ..eagle.utils import RMSNorm, expand_mask, make_causal_mask
52+
from ..eagle.utils import expand_mask, make_causal_mask
5353
from ..medusa.conversion import MedusaDMRegistry
5454
from ..medusa.medusa_model import MedusaModel
5555
from ..utils import (
@@ -219,7 +219,7 @@ def __init__(self, config, decoder_layer_cls, bias=False):
219219
[decoder_layer_cls(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
220220
)
221221
if config.use_last_layernorm:
222-
self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
222+
self.norm = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps)
223223

224224
# Optionally, we use a smaller vocab table for eagle module
225225
if config.draft_vocab_size != config.vocab_size or config.has_lm_head:

0 commit comments

Comments
 (0)