
Commit 53da3fd

fix: fixes eager conversion on GPU
Parent: 1d909b2

1 file changed: 7 additions & 3 deletions

File tree

src/modalities/conversion/gpt2/modeling_gpt2.py

@@ -161,7 +161,9 @@ def eager_attention_forward(
         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
         attn_weights = attn_weights + causal_mask
 
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    # Note we do not upcast the attention weights to float32 here, as it introduces
+    # noise in the attention weights and is not necessary when using BF16
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=query.dtype)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
     attn_output = torch.matmul(attn_weights, value_states)
     attn_output = attn_output.transpose(1, 2).contiguous()
@@ -479,14 +481,16 @@ def forward(
         )
 
 
-class GPT2ForSequenceClassification(GenericForSequenceClassification, GPT2PreTrainedModel): ...
+class GPT2ForSequenceClassification(GenericForSequenceClassification, GPT2PreTrainedModel):
+    ...
 
 
 class GPT2ForQuestionAnswering(GenericForQuestionAnswering, GPT2PreTrainedModel):
     base_model_prefix = "transformer"  # For BC, where `transformer` was used instead of `model`
 
 
-class GPT2ForTokenClassification(GenericForTokenClassification, GPT2PreTrainedModel): ...
+class GPT2ForTokenClassification(GenericForTokenClassification, GPT2PreTrainedModel):
+    ...
 
 
 __all__ = [
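
The first hunk is the substantive change: the old path upcast the attention logits to float32 for the softmax and then cast the result back to the query dtype, while the new path keeps the softmax entirely in the query dtype. A minimal sketch of the difference, assuming a BF16 run (the tensor `scores` and its shape are made up for illustration):

    import torch
    import torch.nn as nn

    # Illustrative only: `scores` stands in for the raw attention logits
    # (query @ key^T, already scaled and masked); the shape is arbitrary.
    scores = torch.randn(1, 4, 8, 8, dtype=torch.bfloat16)

    # Old path: upcast to float32 for the softmax, then cast back to BF16.
    upcast = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(scores.dtype)

    # New path (this commit): compute the softmax directly in the query dtype.
    native = nn.functional.softmax(scores, dim=-1, dtype=scores.dtype)

    # The two paths can disagree by a few BF16 ULPs; per the in-code comment,
    # this float32 round trip is what perturbed the converted model's
    # attention weights on GPU.
    print((upcast - native).abs().max())

The second hunk only reflows the single-line class stubs so that `...` sits on its own line; it does not change behavior.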
