diff --git a/src/torchmetrics/functional/multimodal/clip_iqa.py b/src/torchmetrics/functional/multimodal/clip_iqa.py
index 44b097ac8e2..7fef9389c0a 100644
--- a/src/torchmetrics/functional/multimodal/clip_iqa.py
+++ b/src/torchmetrics/functional/multimodal/clip_iqa.py
@@ -175,6 +175,9 @@ def _clip_iqa_get_anchor_vectors(
         anchors = model.get_text_features(
             text_processed["input_ids"].to(device), text_processed["attention_mask"].to(device)
         )
+        # Handle both tensor and BaseModelOutputWithPooling returns (transformers v5)
+        if hasattr(anchors, "pooler_output"):
+            anchors = anchors.pooler_output
     return anchors / anchors.norm(p=2, dim=-1, keepdim=True)
 
 
@@ -198,6 +201,9 @@ def _clip_iqa_update(
     else:
         processed_input = processor(images=[i.cpu() for i in images], return_tensors="pt", padding=True)
         img_features = model.get_image_features(processed_input["pixel_values"].to(device))
+        # Handle both tensor and BaseModelOutputWithPooling returns (transformers v5)
+        if hasattr(img_features, "pooler_output"):
+            img_features = img_features.pooler_output
     return img_features / img_features.norm(p=2, dim=-1, keepdim=True)
diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py
index e7a8e82669a..9bae2c6be7d 100644
--- a/src/torchmetrics/functional/multimodal/clip_score.py
+++ b/src/torchmetrics/functional/multimodal/clip_score.py
@@ -130,8 +130,8 @@ def _get_features(
     if modality == "image":
         image_data = [i for i in data if isinstance(i, Tensor)]  # Add type checking for images
         processed = processor(images=[i.cpu() for i in image_data], return_tensors="pt", padding=True)
-        return model.get_image_features(processed["pixel_values"].to(device))
-    if modality == "text":
+        features = model.get_image_features(processed["pixel_values"].to(device))
+    elif modality == "text":
         processed = processor(text=data, return_tensors="pt", padding=True)
         if hasattr(model.config, "text_config") and hasattr(model.config.text_config, "max_position_embeddings"):
             max_position_embeddings = model.config.text_config.max_position_embeddings
@@ -144,8 +144,13 @@ def _get_features(
             )
             processed["attention_mask"] = processed["attention_mask"][..., :max_position_embeddings]
             processed["input_ids"] = processed["input_ids"][..., :max_position_embeddings]
-        return model.get_text_features(processed["input_ids"].to(device), processed["attention_mask"].to(device))
-    raise ValueError(f"invalid modality {modality}")
+        features = model.get_text_features(processed["input_ids"].to(device), processed["attention_mask"].to(device))
+    else:
+        raise ValueError(f"invalid modality {modality}")
+    # Handle both tensor and BaseModelOutputWithPooling returns (transformers v5)
+    if hasattr(features, "pooler_output"):
+        features = features.pooler_output
+    return features
 
 
 def _clip_score_update(
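
All three hunks apply the same compatibility pattern: under transformers v5, `get_text_features` / `get_image_features` may return a `BaseModelOutputWithPooling` instead of a bare tensor, and the pooled embedding then lives in its `pooler_output` attribute. Below is a minimal self-contained sketch of the pattern; the `_unwrap_features` helper and the `_FakePooledOutput` stand-in are illustrative names, not part of this diff:

```python
from dataclasses import dataclass
from typing import Any

import torch
from torch import Tensor


@dataclass
class _FakePooledOutput:
    """Stand-in for transformers' BaseModelOutputWithPooling (illustrative only)."""

    pooler_output: Tensor


def _unwrap_features(out: Any) -> Tensor:
    """Return the embedding tensor from either a bare tensor (v4) or an output object (v5)."""
    if hasattr(out, "pooler_output"):
        out = out.pooler_output
    return out


# Both return styles normalize to the same L2-normalized embedding:
for raw in (torch.randn(2, 512), _FakePooledOutput(torch.randn(2, 512))):
    feats = _unwrap_features(raw)
    print((feats / feats.norm(p=2, dim=-1, keepdim=True)).shape)  # torch.Size([2, 512])
```

Duck-typing on `hasattr(out, "pooler_output")` rather than an `isinstance` check avoids importing `BaseModelOutputWithPooling` (whose location can differ across transformers releases) and keeps a single code path working under both v4 and v5.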