diff --git a/server/text_generation_server/layers/layernorm.py b/server/text_generation_server/layers/layernorm.py
index 8c7a2eb048f..2055fc65900 100644
--- a/server/text_generation_server/layers/layernorm.py
+++ b/server/text_generation_server/layers/layernorm.py
@@ -36,9 +36,12 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps):
 if SYSTEM == "cuda":
     import dropout_layer_norm
 
+    major, _ = torch.cuda.get_device_capability()
+    is_blackwell = major > 9
+
     class FastLayerNorm(nn.LayerNorm):
         def forward(self, hidden_states, residual=None):
-            if hidden_states.shape[-1] > 8192:
+            if hidden_states.shape[-1] > 8192 or is_blackwell:
                 if residual is not None:
                     hidden_states += residual
                 residual = hidden_states
@@ -142,7 +145,7 @@ def forward(self, hidden_states, residual=None):
                 self.variance_epsilon,
             )
             return out, residual
-        elif hidden_states.shape[-1] > 8192:
+        elif hidden_states.shape[-1] > 8192 or is_blackwell:
             if residual is not None:
                 hidden_states += residual
             residual = hidden_states