Skip to content

Commit dbcd5fd

Browse files
committed
style: ruff formatting
1 parent e771b57 commit dbcd5fd

2 files changed

Lines changed: 3 additions & 4 deletions

File tree

src/pruna/algorithms/token_merging.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,7 @@ def forward(
248248
attn_weights = attn_weights * head_mask
249249

250250
attn_weights = attn_weights.softmax(dim=-1)
251-
attn_probs = torch.nn.functional.dropout(
252-
attn_weights, p=self.dropout_prob if self.training else 0.0
253-
)
251+
attn_probs = torch.nn.functional.dropout(attn_weights, p=self.dropout_prob if self.training else 0.0)
254252

255253
context_layer = (attn_probs @ value_layer).transpose(1, 2)
256254
context_layer = context_layer.reshape(batch_size, -1, self.all_head_size)

src/pruna/engine/model_checks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def is_vit(model: Any) -> bool:
140140
-------
141141
bool
142142
True if the model is a ViT model, False otherwise.
143-
"""
143+
"""
144144
return model.__class__.__name__ == "ViTForImageClassification"
145145

146146

@@ -160,6 +160,7 @@ def is_transformers_pipeline_with_vit(model: Any) -> bool:
160160
"""
161161
return isinstance(model, ImageClassificationPipeline) and is_vit(getattr(model, "model", None))
162162

163+
163164
def is_transformers_pipeline_with_causal_lm(model: Any) -> bool:
164165
"""
165166
Check if the model is a transformers pipeline (for tasks like text generation, classification, etc.).

0 commit comments

Comments
 (0)