@@ -161,7 +161,9 @@ def eager_attention_forward(
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask

-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    # Note we do not upcast the attention weights to float32 here, as it introduces
+    # noise in the attention weights and is not necessary when using BF16
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=query.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
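For context, a minimal standalone sketch (not part of the diff, with hypothetical random inputs) contrasting the removed float32-upcast softmax path with the new same-dtype path; both return BF16 weights, differing only by rounding inside the softmax:

# Sketch only: compares the old and new softmax paths from the hunk above.
import torch
import torch.nn.functional as F

# Hypothetical BF16 attention scores of shape (batch, heads, q_len, k_len).
attn_weights = torch.randn(2, 4, 8, 8, dtype=torch.bfloat16)

# Old path: upcast to float32 for the softmax, then cast back to the query dtype.
old = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attn_weights.dtype)

# New path: keep the softmax in the query dtype (BF16 here), no round-trip cast.
new = F.softmax(attn_weights, dim=-1, dtype=attn_weights.dtype)

print(old.dtype, new.dtype, (old - new).abs().max())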
@@ -479,14 +481,16 @@ def forward(
     )


-class GPT2ForSequenceClassification(GenericForSequenceClassification, GPT2PreTrainedModel): ...
+class GPT2ForSequenceClassification(GenericForSequenceClassification, GPT2PreTrainedModel):
+    ...


 class GPT2ForQuestionAnswering(GenericForQuestionAnswering, GPT2PreTrainedModel):
     base_model_prefix = "transformer"  # For BC, where `transformer` was used instead of `model`


-class GPT2ForTokenClassification(GenericForTokenClassification, GPT2PreTrainedModel): ...
+class GPT2ForTokenClassification(GenericForTokenClassification, GPT2PreTrainedModel):
+    ...


 __all__ = [