diff --git a/docs/source/links.rst b/docs/source/links.rst index 539d2728e74..0bdde6a0a91 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -50,6 +50,8 @@ .. _Mean Reciprocal Rank: https://en.wikipedia.org/wiki/Mean_reciprocal_rank .. _BERT_score: https://github.com/Tiiiger/bert_score/blob/master/bert_score/utils.py .. _Bert_score Evaluating Text Generation: https://arxiv.org/abs/1904.09675 +.. _DepthScore Evaluating Text Generation: https://arxiv.org/abs/2103.12711 +.. _DEPTH_score: https://github.com/PierreColombo/nlg_eval_via_simi_measures/blob/main/nlg_eval_via_simi_measures/depth_score.py .. _BLEU score: https://en.wikipedia.org/wiki/BLEU .. _BLEU: https://www.semanticscholar.org/paper/Bleu%3A-a-Method-for-Automatic-Evaluation-of-Machine-Papineni-Roukos/d7da009f457917aa381619facfa5ffae9329a6e9 .. _SacreBLEU: https://github.com/mjpost/sacrebleu diff --git a/docs/source/text/depth_score.rst b/docs/source/text/depth_score.rst new file mode 100644 index 00000000000..f73665e2886 --- /dev/null +++ b/docs/source/text/depth_score.rst @@ -0,0 +1,21 @@ +.. customcarditem:: + :header: Depth Score + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/summarization.svg + :tags: Text + +.. include:: ../links.rst + +########### +Depth Score +########### + +Module Interface +________________ + +.. autoclass:: torchmetrics.text.depth_score.DepthScore + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.text.depth_score.depth_score diff --git a/requirements/text.txt b/requirements/text.txt index 188e4dc1e77..672650c99c6 100644 --- a/requirements/text.txt +++ b/requirements/text.txt @@ -9,3 +9,7 @@ transformers >=4.43.0,<4.57 mecab-python3 >=1.0.6, <1.1.0 ipadic >=1.0.0, <1.1.0 sentencepiece >=0.2.0, <0.3.0 + +scikit-learn >1.5.0, <1.8.0 +POT >=0.9.0, <=0.9.6 +geomloss ==0.2.6 # strict diff --git a/src/torchmetrics/functional/text/__init__.py b/src/torchmetrics/functional/text/__init__.py index 9282be6fbae..acc84a0373d 100644 --- a/src/torchmetrics/functional/text/__init__.py +++ b/src/torchmetrics/functional/text/__init__.py @@ -48,6 +48,7 @@ if _TRANSFORMERS_GREATER_EQUAL_4_4: from torchmetrics.functional.text.bert import bert_score + from torchmetrics.functional.text.depth_score import depth_score from torchmetrics.functional.text.infolm import infolm - __all__ += ["bert_score", "infolm"] + __all__ += ["bert_score", "depth_score", "infolm"] diff --git a/src/torchmetrics/functional/text/depth_score.py b/src/torchmetrics/functional/text/depth_score.py new file mode 100644 index 00000000000..f27055bc46f --- /dev/null +++ b/src/torchmetrics/functional/text/depth_score.py @@ -0,0 +1,721 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from collections.abc import Iterator, Sequence +from contextlib import contextmanager +from typing import Any, Callable, List, Optional, Tuple, Union, cast + +import numpy as np +import torch +from torch import Tensor +from torch.nn import Module +from torch.utils.data import DataLoader + +# TorchMetrics text helpers (same style as BERTScore) +from torchmetrics.functional.text.helper_embedding_metric import ( + TextDataset, + TokenizedDataset, + _check_shape_of_model_output, + _get_progress_bar, + _input_data_collator, + _output_data_collator, +) +from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout +from torchmetrics.utilities.imports import ( + _GEOMLOSS_AVAILABLE, + _POT_AVAILABLE, + _SKLEARN_AVAILABLE, + _TQDM_AVAILABLE, + _TRANSFORMERS_GREATER_EQUAL_4_4, +) + +log = logging.getLogger(__name__) + + +@contextmanager +def _ignore_transformers_finetune_warning() -> Iterator[None]: + """Temporarily silence common transformers loading warnings.""" + logger = logging.getLogger("transformers.modeling_utils") + original_level = logger.getEffectiveLevel() + try: + logger.setLevel(logging.ERROR) + yield + finally: + logger.setLevel(original_level) + + +# Default model recommended in the original implementation. +_DEFAULT_MODEL = "bert-base-uncased" + +if _TRANSFORMERS_GREATER_EQUAL_4_4: + from transformers import AutoModel, AutoTokenizer + + def _download_model_for_depth_score() -> None: + """Download intensive operations.""" + with _ignore_transformers_finetune_warning(): + AutoTokenizer.from_pretrained(_DEFAULT_MODEL) + AutoModel.from_pretrained(_DEFAULT_MODEL) + + if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_model_for_depth_score): + __doctest_skip__ = ["depth_score"] +else: + __doctest_skip__ = ["depth_score"] + + +def _preprocess_multiple_references( + preds: List[str], target: List[Union[str, Sequence[str]]] +) -> Tuple[List[str], List[str], Optional[List[Tuple[int, int]]]]: + """Preprocess predictions and targets when dealing with multiple references. + + This function handles the case where a single prediction might have multiple + reference targets (represented as a list/tuple of strings). It flattens the + multi-reference structure into aligned (pred, ref) pairs and returns group + boundaries so the final distance can later be reduced per original prediction. + + Args: + preds: A list of predictions. + target: A list of targets, where each item could be a string or a list/tuple of strings. + + Returns: + Tuple: (preds, target, ref_group_boundaries) + - preds: Flattened list of `str` where each prediction is repeated once per reference. + - target: Flattened list of `str` containing all references. + - ref_group_boundaries: List of tuples (start, end) indicating the boundaries of each + original prediction's reference group in the flattened lists, or `None` if no + multi-reference structure is present. + + Raises: + ValueError: + If `preds` is not a list of strings. + + """ + if not all(isinstance(item, str) for item in preds): + raise ValueError("Invalid input provided.") + + has_nested = any(isinstance(item, (list, tuple)) for item in target) + if not has_nested: + return preds, cast(List[str], target), None + + ref_group_boundaries: List[Tuple[int, int]] = [] + new_preds: List[str] = [] + new_target: List[str] = [] + count = 0 + + for pred, ref_group in zip(preds, target): + if isinstance(ref_group, (list, tuple)): + new_preds.extend([pred] * len(ref_group)) + new_target.extend(cast(List[str], ref_group)) + ref_group_boundaries.append((count, count + len(ref_group))) + count += len(ref_group) + else: + new_preds.append(pred) + new_target.append(cast(str, ref_group)) + ref_group_boundaries.append((count, count + 1)) + count += 1 + + return new_preds, new_target, ref_group_boundaries + + +def _postprocess_multiple_references_distance( + distances: Tensor, + ref_group_boundaries: List[Tuple[int, int]], + reduction: str = "min", +) -> Tensor: + """Postprocess distances when dealing with multiple references. + + After `_preprocess_multiple_references` flattens multi-reference inputs, this function + reduces the computed per-(pred, ref) distances back to a single distance per original + prediction by aggregating within each reference group. + + Since DepthScore is a distance (lower is better), the default behavior uses `min` + (best matching reference). Other reductions can be used for diagnostics. + + Args: + distances: A 1D tensor of distances aligned with the flattened (pred, ref) pairs. + ref_group_boundaries: List of tuples (start, end) indicating the boundaries of each + original prediction's reference group in `distances`. + reduction: Reduction to apply within each group. One of `{"min", "max", "mean"}`. + - `"min"`: best reference match (default for distance metrics) + - `"max"`: worst reference match + - `"mean"`: average across references + + Returns: + A 1D tensor of shape `(num_predictions,)` containing the reduced distance per prediction. + + Raises: + ValueError: + If `distances` is not 1D. + ValueError: + If `reduction` is not one of `{"min","max","mean"}`. + + """ + if distances.dim() != 1: + raise ValueError("Expected 1D tensor of distances.") + if reduction not in {"min", "max", "mean"}: + raise ValueError("reduction must be one of {'min','max','mean'}.") + + out: List[Tensor] = [] + for start, end in ref_group_boundaries: + chunk = distances[start:end] + if reduction == "min": + out.append(chunk.min()) + elif reduction == "max": + out.append(chunk.max()) + else: + out.append(chunk.mean()) + return torch.stack(out, dim=0) + + +def cov_matrix(x: np.ndarray, robust: bool = False) -> np.ndarray: + """Covariance matrix (optionally robust).""" + if robust: + if not _SKLEARN_AVAILABLE: + raise ModuleNotFoundError( + "Robust covariance requires that `scikit-learn` is installed. " + "Use `pip install scikit-learn` or `pip install torchmetrics[text]`." + ) + from sklearn.covariance import MinCovDet as MCD # noqa: N817 + + return MCD().fit(x).covariance_ + return np.cov(x.T) + + +def standardize(x: np.ndarray, robust: bool = False) -> np.ndarray: + """Affine standardization using inverse sqrt covariance.""" + sigma = cov_matrix(x, robust) + _, n_features = x.shape + rank = np.linalg.matrix_rank(x) + + if rank < n_features: + if not _SKLEARN_AVAILABLE: + raise ModuleNotFoundError( + "Affine-invariant DepthScore requires that `scikit-learn` is installed. " + "Use `pip install scikit-learn` or `pip install torchmetrics[text]`." + ) + from sklearn.decomposition import PCA + + x = PCA(rank).fit_transform(x) + sigma = cov_matrix(x) + + u, s, _ = np.linalg.svd(sigma) + square_inv = u / np.sqrt(s) + return x @ square_inv + + +def sampled_sphere(n_dirs: int, d: int) -> np.ndarray: + """Uniform samples on unit sphere.""" + u = np.random.multivariate_normal(np.zeros(d), np.eye(d), size=n_dirs) + # The reference implementation uses `sklearn.preprocessing.normalize`. Here, that is mocked + # so default irw metric runs without any additional dependencies being installed. + return _normalize_l2_rows_exact(u) + + +def _normalize_l2_rows_exact(x: np.ndarray) -> np.ndarray: + norms = np.sqrt(np.einsum("ij,ij->i", x, x)) + norms[norms == 0.0] = 1.0 + return x / norms[:, None] + + +def wasserstein(x: np.ndarray, y: np.ndarray) -> float: + """Optimal transport cost with uniform weights.""" + if not _POT_AVAILABLE: + raise ModuleNotFoundError( + "The `wasserstein` backend requires that `POT` is installed. " + "Use `pip install POT` or `pip install torchmetrics[text]`." + ) + import ot # pip install POT # codespell:ignore ot + + m = ot.dist(x, y) # codespell:ignore ot + w_x = np.ones(len(x)) / len(x) + w_y = np.ones(len(y)) / len(y) + return float(ot.emd2(w_x, w_y, m)) # codespell:ignore ot + + +def sw(x: np.ndarray, y: np.ndarray, ndirs: int, p: int = 2) -> float: + """Sliced Wasserstein distance.""" + if not _POT_AVAILABLE: + raise ModuleNotFoundError( + "The `sliced` backend requires that `POT` is installed. " + "Use `pip install POT` or `pip install torchmetrics[text]`." + ) + import ot # pip install POT # codespell:ignore ot + + n, d = x.shape + u = sampled_sphere(ndirs, d) + z_x = x @ u.T + z_y = y @ u.T + sliced = np.zeros(ndirs) + for k in range(ndirs): + sliced[k] = ot.emd2_1d(z_x[:, k], z_y[:, k], p=2) # codespell:ignore ot + return float((np.mean(sliced)) ** (1 / p)) + + +def mmd(x: np.ndarray, y: np.ndarray) -> float: + """Gaussian MMD via geomloss.""" + if not _GEOMLOSS_AVAILABLE: + raise ModuleNotFoundError( + "The `mmd` backend requires that `geomloss` is installed. " + "Use `pip install geomloss` or `pip install torchmetrics[text]`." + ) + import geomloss + + return float(geomloss.SamplesLoss("gaussian")(torch.tensor(x), torch.tensor(y)).item()) + + +def ai_irw( + x: np.ndarray, ai: bool = True, robust: bool = False, n_dirs: Optional[int] = None, random_state: int = 0 +) -> np.ndarray: + """(Affine-invariant) integrated rank-weighted depth.""" + np.random.seed(random_state) + if ai: + x = standardize(x, robust) + + n, d = x.shape + n_dirs = d * 100 if n_dirs is None else n_dirs + + u = sampled_sphere(n_dirs, d) + proj = x @ u.T + ranks = np.argsort(proj, axis=0) + + depth = np.zeros_like(proj) + for k in range(n_dirs): + depth[ranks[:, k], k] = np.arange(1, n + 1) + + depth = depth / n + depth = np.minimum(depth, 1 - depth) + return np.mean(depth, axis=1) + + +def dr_distance( + x: np.ndarray, + y: np.ndarray, + n_alpha: int = 5, + n_dirs: int = 10000, + data_depth: str = "irw", + eps_min: float = 0.3, + eps_max: float = 1.0, + p: int = 5, + random_state: int = 0, +) -> float: + """Compute the depth-based pseudo-metric between two point clouds. + + This function implements the DepthScore "DR distance" between two empirical + distributions represented by token-embedding point clouds `x` and `y`. The distance + is computed by (1) choosing a data depth / distributional discrepancy backend + (e.g., IRW depth, affine-invariant IRW, Wasserstein, sliced Wasserstein, or MMD), + and (2) integrating over depth level sets between `eps_min` and `eps_max`, while + approximating the supremum over directions on the unit sphere by Monte Carlo. + + Args: + x: Array of shape `(n_samples, n_features)` representing the first point cloud. + y: Array of shape `(n_samples, n_features)` representing the second point cloud. + n_alpha: Monte-Carlo parameter controlling the approximation of the integral + over alpha (number of level-set thresholds between `eps_min` and `eps_max`). + n_dirs: Number of random directions used to approximate the supremum over the + unit sphere (and for depth estimation when applicable). + data_depth: Depth / discrepancy measure to use. One of + `{"irw", "ai_irw", "wasserstein", "sliced", "mmd"}`. + - `"irw"` / `"ai_irw"` compute depth values and then integrate level sets. + - `"wasserstein"` returns the (unsliced) OT cost directly. # codespell:ignore ot + - `"sliced"` returns the sliced Wasserstein distance directly. + - `"mmd"` returns the Gaussian MMD directly. + eps_min: Lower level-set bound in `[0, eps_max]` (lowest alpha / quantile level). + eps_max: Upper level-set bound in `[eps_min, 1]` (highest alpha / quantile level). + p: Power used in the ground cost aggregation (corresponds to the exponent in the + reference implementation). + random_state: Random seed controlling direction sampling and any stochastic steps. + + Returns: + The computed pseudo-metric score as a Python `float`. + + Raises: + ValueError: + If `data_depth` is unsupported. + ValueError: + If `eps_min` and `eps_max` do not satisfy `0 <= eps_min <= eps_max <= 1`. + + """ + np.random.seed(random_state) + + # Match reference numerics: many reference code paths end up in float64. + x = np.asarray(x, dtype=np.float64) + y = np.asarray(y, dtype=np.float64) + + if data_depth == "irw": + depth_x = ai_irw(x, ai=False, n_dirs=n_dirs, random_state=random_state) + depth_y = ai_irw(y, ai=False, n_dirs=n_dirs, random_state=random_state) + elif data_depth == "ai_irw": + depth_x = ai_irw(x, ai=True, n_dirs=n_dirs, random_state=random_state) + depth_y = ai_irw(y, ai=True, n_dirs=n_dirs, random_state=random_state) + elif data_depth == "wasserstein": + return wasserstein(x, y) + elif data_depth == "sliced": + return sw(x, y, ndirs=n_dirs) + elif data_depth == "mmd": + return mmd(x, y) + else: + raise ValueError("Unsupported depth") + + if not (0.0 <= eps_min <= eps_max <= 1.0): + raise ValueError("Expected 0 <= eps_min <= eps_max <= 1") + + _, d = x.shape + u = sampled_sphere(n_dirs, d) + proj_x = x @ u.T + proj_y = y @ u.T + + alphas = np.linspace(int(eps_min * 100), int(eps_max * 100), n_alpha) + q_x = [np.percentile(depth_x, a) for a in alphas] + q_y = [np.percentile(depth_y, a) for a in alphas] + + score = 0.0 + for i in range(n_alpha): + idx_x = np.where(depth_x >= q_x[i])[0] + idx_y = np.where(depth_y >= q_y[i])[0] + supp_x = np.max(proj_x[idx_x], axis=0) + supp_y = np.max(proj_y[idx_y], axis=0) + score += float(np.max((supp_x - supp_y) ** p)) + + return float((score / n_alpha) ** (1 / p)) + + +def _get_embeddings_and_mask( + dataloader: DataLoader, + target_len: int, + model: Module, + device: Optional[Union[str, torch.device]] = None, + num_layers: Optional[int] = None, + all_layers: bool = False, + verbose: bool = False, + user_forward_fn: Optional[Callable[[Module, dict[str, Tensor]], Tensor]] = None, +) -> Tuple[Tensor, Tensor]: + """Compute normalized token embeddings and the corresponding attention mask. + + Args: + dataloader: Dataloader over `TextDataset` or `TokenizedDataset`. + target_len: Length of the longest sequence in the dataset (used for output collation/padding). + model: Transformer model used for embedding extraction. + device: Device to run inference on. + num_layers: Which hidden layer to use from `output_hidden_states`. + If `None`, the last layer is used. + all_layers: Whether to use representations from all layers. + If `True`, `num_layers` is ignored. + verbose: Whether to show a progress bar during embedding extraction. + user_forward_fn: + Optional user-defined forward function. If provided, it must: + - accept `(model, batch_dict)` where `batch_dict` contains `"input_ids"` and `"attention_mask"` + - return a tensor shaped like `(batch, seq_len, hidden_dim)`. + + Returns: + A tuple `(embeddings, attention_mask)` where: + - embeddings: Tensor shaped `(batch, 1, seq_len, hidden_dim)` when `all_layers=False`, + or `(batch, num_layers, seq_len, hidden_dim)` when `all_layers=True`. + Embeddings are L2-normalized over the hidden dimension and masked by `attention_mask`. + - attention_mask: Tensor shaped `(batch, seq_len)` aligned with `embeddings`. + + Raises: + ValueError: + If `user_forward_fn` output shape does not match the expected model output shape. + ValueError: + If `all_layers=True` is used with a custom `user_forward_fn`. + + """ + embeddings_list: List[Tensor] = [] + mask_list: List[Tensor] = [] + + for batch in _get_progress_bar(dataloader, verbose): + with torch.no_grad(): + batch = _input_data_collator(batch, device) + + if not all_layers: + if user_forward_fn is None: + out = model(batch["input_ids"], batch["attention_mask"], output_hidden_states=True) + hs = out.hidden_states[num_layers if num_layers is not None else -1] + else: + hs = user_forward_fn(model, batch) + _check_shape_of_model_output(hs, batch["input_ids"]) + # unify to (b, 1, s, d) like BERTScore's internal shape + hs = hs.unsqueeze(1) + else: + if user_forward_fn is not None: + raise ValueError( + "The option `all_layers=True` can be used only with default `transformers` models." + ) + out = model(batch["input_ids"], batch["attention_mask"], output_hidden_states=True) + hs = torch.cat([o.unsqueeze(1) for o in out.hidden_states], dim=1) + + # normalize embeddings (safe) + denom = hs.norm(dim=-1).unsqueeze(-1).clamp_min(1e-12) + hs = hs / denom + + hs, attention_mask = _output_data_collator(hs, batch["attention_mask"], target_len) + + # mask out padding/special tokens + hs = torch.einsum("blsd, bs -> blsd", hs, attention_mask) + + embeddings_list.append(hs.cpu()) + mask_list.append(attention_mask.cpu()) + + return torch.cat(embeddings_list, dim=0), torch.cat(mask_list, dim=0) + + +def depth_score( + preds: Union[str, Sequence[str], dict[str, Tensor]], + target: Union[str, Sequence[str], Sequence[Sequence[str]], dict[str, Tensor]], + model_name_or_path: Optional[str] = None, + num_layers: Optional[int] = None, + all_layers: bool = False, + model: Optional[Module] = None, + user_tokenizer: Any = None, + user_forward_fn: Optional[Callable[[Module, dict[str, Tensor]], Tensor]] = None, + verbose: bool = False, + device: Optional[Union[str, torch.device]] = None, + max_length: int = 512, + batch_size: int = 64, + num_threads: int = 0, + truncation: bool = False, + # DepthScore-specific knobs + n_alpha: int = 5, + n_dirs: int = 10000, + eps: float = 0.3, + p: int = 5, + depth_measure: str = "irw", + # Multi-ref postprocess for a distance metric (best = min by default) + multi_ref_reduction: str = "min", +) -> Tensor: + """`DepthScore Evaluating Text Generation`_ for text similarity matching. + + DepthScore measures the distance between two sentences by comparing the distributions + of their contextual token embeddings using a depth-based pseudo-metric. Lower values + indicate that the predicted sentence is closer to the reference sentence. + + This implementation follows the original implementation from `DEPTH_score`_. + + Args: + preds: Predicted sentence(s) as `str`, `Sequence[str]`, or tokenized dict + containing `"input_ids"` and `"attention_mask"`. + target: Reference sentence(s) as `str`, `Sequence[str]`, multi-reference + `Sequence[Sequence[str]]`, or tokenized dict containing `"input_ids"` and `"attention_mask"`. + model_name_or_path: Hugging Face model name/path used when `model` is not provided. + num_layers: Hidden layer index to use for contextual embeddings. If `None`, the last layer is used. + all_layers: + An indication of whether the representation from all model's layers should be used. + If ``all_layers=True``, the argument ``num_layers`` is ignored. + model: Optional user-provided model. If provided, `user_tokenizer` must also be provided. + user_tokenizer: Tokenizer to use with a user-provided model. Ignored when `model` is `None`. + user_forward_fn: + Optional user-defined forward function producing embeddings from `(model, batch_dict)`. + verbose: Whether to show a progress bar during embedding extraction. + device: Device to run embedding extraction on. + max_length: Maximum input sequence length. Longer sequences are trimmed if `truncation=True`. + batch_size: Batch size used for model processing. + num_threads: Number of dataloader workers. + truncation: Whether to truncate input sequences to `max_length`. + n_alpha: Number of alpha levels used by the depth-based distance computation. + n_dirs: Number of random projection directions used by depth/sliced computations. + eps: Lower quantile bound (eps_min) used in the depth distance integration (upper bound fixed at 1.0). + p: Power used in the distance aggregation. + depth_measure: Depth/distance backend to use. One of: + `"irw"`, `"ai_irw"`, `"wasserstein"`, `"sliced"`, `"mmd"`. + multi_ref_reduction: Reduction to apply across multiple references per prediction. + Default `"min"` (best match) since this is a distance metric. + + Returns: + A 1D tensor of distances of shape `(num_predictions,)`. For multi-reference input, + the output is reduced per original prediction according to `multi_ref_reduction`. + + Raises: + ValueError: + If `len(preds) != len(target)`. + ModuleNotFoundError: + If `verbose=True` but `tqdm` is not installed. + ModuleNotFoundError: + If default transformers model is required but `transformers` is not installed. + ValueError: + If invalid input is provided for `preds`/`target`. + ValueError: + If `num_layers` is larger than the number of model layers (when detectable). + + Example: + >>> from torchmetrics.functional.text.depth_score import depth_score + >>> preds = ["hello there", "general kenobi"] + >>> target = ["hello there", "master kenobi"] + >>> depth_score(preds, target, model_name_or_path="distilbert-base-uncased", num_layers=4, device="cpu") + tensor([...]) + + Example: + >>> from torchmetrics.functional.text.depth_score import depth_score + >>> preds = ["hello there", "general kenobi"] + >>> target = [["hello there", "master kenobi"], ["hello there", "master kenobi"]] + >>> depth_score(preds, target, model_name_or_path="distilbert-base-uncased", num_layers=4, device="cpu") + tensor([...]) + + """ + ref_group_boundaries: Optional[List[Tuple[int, int]]] = None + + if isinstance(preds, str): + preds = [preds] + if isinstance(target, str): + target = [target] + if not isinstance(preds, (list, dict)): + preds = list(preds) + if not isinstance(target, (list, dict)): + target = list(target) + + if len(preds) != len(target): + raise ValueError( + "Expected number of predicted and reference sentences to be the same, but got" + f" {len(preds)} and {len(target)}" + ) + + if isinstance(preds, list) and len(preds) > 0 and isinstance(target, list) and len(target) > 0: + preds, target, ref_group_boundaries = _preprocess_multiple_references(preds, target) + + if verbose and (not _TQDM_AVAILABLE): + raise ModuleNotFoundError( + "An argument `verbose = True` requires `tqdm` package be installed. Install with `pip install tqdm`." + ) + + if model is None: + if not _TRANSFORMERS_GREATER_EQUAL_4_4: + raise ModuleNotFoundError( + "`depth_score` metric with default models requires `transformers` package be installed." + " Either install with `pip install transformers>=4.4` or `pip install torchmetrics[text]`." + ) + if model_name_or_path is None: + rank_zero_warn( + "The argument `model_name_or_path` was not specified while it is required when default" + " `transformers` model are used." + f" It is, therefore, used the default recommended model - {_DEFAULT_MODEL}." + ) + from transformers import AutoModel, AutoTokenizer + + with _ignore_transformers_finetune_warning(): + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path or _DEFAULT_MODEL) + model = AutoModel.from_pretrained(model_name_or_path or _DEFAULT_MODEL) + else: + if user_tokenizer is None: + raise ValueError("When `model` is provided, `user_tokenizer` must also be provided.") + tokenizer = user_tokenizer + + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + model.eval() + model.to(device) + + try: + if hasattr(model.config, "num_hidden_layers") and isinstance(model.config.num_hidden_layers, int): + if num_layers and num_layers > model.config.num_hidden_layers: + raise ValueError( + f"num_layers={num_layers} is forbidden for {model_name_or_path}." + f" Please use num_layers <= {model.config.num_hidden_layers}" + ) + else: + rank_zero_warn( + "Model config does not have `num_hidden_layers` as an integer attribute. " + "Unable to validate `num_layers`." + ) + except AttributeError: + rank_zero_warn("It was not possible to retrieve the parameter `num_layers` from the model specification.") + + _are_empty_lists = all(isinstance(text, list) and len(text) == 0 for text in (preds, target)) + _are_valid_lists = all( + isinstance(text, list) and len(text) > 0 and isinstance(text[0], str) for text in (preds, target) + ) + _are_valid_tensors = all( + isinstance(text, dict) and isinstance(text["input_ids"], Tensor) for text in (preds, target) + ) + + if _are_empty_lists: + rank_zero_warn("Predictions and references are empty.") + return torch.zeros(1, dtype=torch.float32) + + if _are_valid_lists: + target_dataset = TextDataset(target, tokenizer, max_length, truncation=truncation) # type: ignore + preds_dataset = TextDataset(preds, tokenizer, max_length, truncation=truncation) # type: ignore + + elif _are_valid_tensors: + target_dataset = TokenizedDataset(**target) # type: ignore + preds_dataset = TokenizedDataset(**preds) # type: ignore + else: + raise ValueError("Invalid input provided.") + + target_loader = DataLoader(target_dataset, batch_size=batch_size, num_workers=num_threads) + preds_loader = DataLoader(preds_dataset, batch_size=batch_size, num_workers=num_threads) + + target_embeddings, target_mask = _get_embeddings_and_mask( + target_loader, + target_dataset.max_length, + model, + device=device, + num_layers=num_layers, + all_layers=all_layers, + verbose=verbose, + user_forward_fn=user_forward_fn, + ) + preds_embeddings, preds_mask = _get_embeddings_and_mask( + preds_loader, + preds_dataset.max_length, + model, + device=device, + num_layers=num_layers, + all_layers=all_layers, + verbose=verbose, + user_forward_fn=user_forward_fn, + ) + + # Reorder back (TextDataset sorts by length internally) + target_embeddings = target_embeddings[target_loader.dataset.sorting_indices] + preds_embeddings = preds_embeddings[preds_loader.dataset.sorting_indices] + target_mask = target_mask[target_loader.dataset.sorting_indices] + preds_mask = preds_mask[preds_loader.dataset.sorting_indices] + + # Pairwise (same index) distances + distances: List[float] = [] + n = preds_embeddings.shape[0] + + for i in range(n): + pm = preds_mask[i].bool() + tm = target_mask[i].bool() + + x = preds_embeddings[i, 0, pm, :].numpy() + y = target_embeddings[i, 0, tm, :].numpy() + + if x.shape[0] == 0 or y.shape[0] == 0: + distances.append(float("inf")) + continue + + distances.append( + dr_distance( + x, + y, + n_alpha=n_alpha, + n_dirs=n_dirs, + data_depth=depth_measure, + eps_min=eps, + eps_max=1.0, + p=p, + random_state=0, + ) + ) + + out = torch.tensor(distances, dtype=torch.float32) + + # Multi-reference reduction (distance metric: default "min" = best ref) + if ref_group_boundaries is not None: + out = _postprocess_multiple_references_distance(out, ref_group_boundaries, reduction=multi_ref_reduction) + + return out diff --git a/src/torchmetrics/text/__init__.py b/src/torchmetrics/text/__init__.py index 6af056246cd..ca297acc152 100644 --- a/src/torchmetrics/text/__init__.py +++ b/src/torchmetrics/text/__init__.py @@ -46,6 +46,7 @@ if _TRANSFORMERS_GREATER_EQUAL_4_4: from torchmetrics.text.bert import BERTScore + from torchmetrics.text.depth_score import DepthScore from torchmetrics.text.infolm import InfoLM - __all__ += ["BERTScore", "InfoLM"] + __all__ += ["BERTScore", "DepthScore", "InfoLM"] diff --git a/src/torchmetrics/text/depth_score.py b/src/torchmetrics/text/depth_score.py new file mode 100644 index 00000000000..5d8b6f5ec0f --- /dev/null +++ b/src/torchmetrics/text/depth_score.py @@ -0,0 +1,357 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Sequence +from typing import Any, Callable, List, Optional, Tuple, Union, cast + +import torch +from torch import Tensor +from torch.nn import Module + +from torchmetrics.functional.text.depth_score import ( + _postprocess_multiple_references_distance, + _preprocess_multiple_references, + depth_score, +) +from torchmetrics.functional.text.helper_embedding_metric import _preprocess_text +from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout +from torchmetrics.utilities.data import dim_zero_cat +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TRANSFORMERS_GREATER_EQUAL_4_4 +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["DepthScore.plot"] + +# Default model recommended in the original implementation. +_DEFAULT_MODEL: str = "bert-base-uncased" + +if _SKIP_SLOW_DOCTEST and _TRANSFORMERS_GREATER_EQUAL_4_4: + from transformers import AutoModel, AutoTokenizer + + def _download_model_for_depth_score() -> None: + """Download intensive operations.""" + AutoTokenizer.from_pretrained(_DEFAULT_MODEL, resume_download=True) + AutoModel.from_pretrained(_DEFAULT_MODEL, resume_download=True) + + if not _try_proceed_with_timeout(_download_model_for_depth_score): + __doctest_skip__ = ["DepthScore", "DepthScore.plot"] +else: + __doctest_skip__ = ["DepthScore", "DepthScore.plot"] + + +class DepthScore(Metric): + """`DepthScore Evaluating Text Generation`_ for measuring text similarity. + + DepthScore leverages pre-trained contextual token embeddings (e.g., from BERT-like models) and compares + candidate and reference sentences by treating their token embeddings as point clouds and computing a depth- + based pseudo-metric between the two distributions. This distance is designed to capture distributional + mismatches between contextual representations and can be used for evaluating text generation tasks where + *lower* distance indicates a better match. + + This implementation follows the original implementation from `DEPTH_score`_. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds``: Predicted sentence(s). Can be one of: + + * A single predicted sentence as a string (``str``) + * A sequence of predicted sentences (``Sequence[str]``) + + - ``target``: Target/reference sentence(s). Can be one of: + + * A single reference sentence as a string (``str``) + * A sequence of reference sentences (``Sequence[str]``) + * A sequence of sequences of reference sentences for multi-reference evaluation (``Sequence[Sequence[str]]``) + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``score`` (:class:`~torch.Tensor`): A 1D tensor of distances of shape `(num_predictions,)`. For multi-reference + input, the output is reduced per original prediction according to `multi_ref_reduction`. + + Args: + preds (Union[str, Sequence[str]]): A single predicted sentence or a sequence of predicted sentences. + target (Union[str, Sequence[str], Sequence[Sequence[str]]]): A single target sentence, a sequence of target + sentences, or a sequence of sequences of target sentences for multiple references per prediction. + model_name_or_path: A name or a model path used to load ``transformers`` pretrained model. + num_layers: A layer of representation to use. + all_layers: + An indication of whether the representation from all model's layers should be used. + If ``all_layers=True``, the argument ``num_layers`` is ignored. + model: A user's own model. Must be of `torch.nn.Module` instance. + user_tokenizer: + A user's own tokenizer used with the own model. This must be an instance with the ``__call__`` method. + This method must take an iterable of sentences (`List[str]`) and must return a python dictionary + containing `"input_ids"` and `"attention_mask"` represented by :class:`~torch.Tensor`. + It is up to the user's model of whether `"input_ids"` is a :class:`~torch.Tensor` of input ids or embedding + vectors. This tokenizer must prepend an equivalent of ``[CLS]`` token and append an equivalent of ``[SEP]`` + token as ``transformers`` tokenizer does. + user_forward_fn: + A user's own forward function used in a combination with ``user_model``. This function must take + ``user_model`` and a python dictionary of containing ``"input_ids"`` and ``"attention_mask"`` represented + by :class:`~torch.Tensor` as an input and return the model's output represented by the single + :class:`~torch.Tensor`. + verbose: An indication of whether a progress bar to be displayed during the embeddings' calculation. + device: A device to be used for calculation. + max_length: A maximum length of input sequences. Sequences longer than ``max_length`` are to be trimmed. + batch_size: A batch size used for model processing. + num_threads: A number of threads to use for a dataloader. + n_alpha: The Monte-Carlo parameter for the approximation of the integral over alpha (number of level-set + thresholds between ``eps`` and 1.0). + eps: The lowest level-set bound in [0, 1]. The highest level set is fixed to 1.0 in this implementation. + p: The power of the ground cost. + depth_measure: Depth / discrepancy measure to use (e.g. ``"irw"`` or ``"ai_irw"``). + truncation: An indication of whether the input sequences should be truncated to the ``max_length``. + multi_ref_reduction: Reduction to apply across multiple references per prediction. + Default ``"min"`` (best match) since this is a distance metric. Options: ``"min"``, ``"max"``, ``"mean"``. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> from pprint import pprint + >>> from torchmetrics.text.depth_score import DepthScore + >>> preds = ["hello there", "general kenobi"] + >>> target = ["hello there", "master kenobi"] + >>> depthscore = DepthScore() + >>> pprint(depthscore(preds, target)) + tensor([...]) + + Example: + >>> from pprint import pprint + >>> from torchmetrics.text.depth_score import DepthScore + >>> preds = ["hello there", "general kenobi"] + >>> target = [["hello there", "master kenobi"], ["hello there", "master kenobi"]] + >>> depthscore = DepthScore() + >>> pprint(depthscore(preds, target)) + tensor([...]) + + """ + + is_differentiable: bool = False + higher_is_better: bool = False # distance + full_state_update: bool = False + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 # not truly bounded; used only for plotting convenience + + preds_input_ids: List[Tensor] + preds_attention_mask: List[Tensor] + target_input_ids: List[Tensor] + target_attention_mask: List[Tensor] + + def __init__( + self, + model_name_or_path: Optional[str] = None, + num_layers: Optional[int] = None, + all_layers: bool = False, + model: Optional[Module] = None, + user_tokenizer: Optional[Any] = None, + user_forward_fn: Optional[Callable[[Module, dict[str, Tensor]], Tensor]] = None, + verbose: bool = False, + device: Optional[Union[str, torch.device]] = None, + max_length: int = 512, + batch_size: int = 64, + num_threads: int = 0, + n_alpha: int = 5, + eps: float = 0.3, + p: int = 5, + depth_measure: str = "irw", + truncation: bool = False, + multi_ref_reduction: str = "min", + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + if not _TRANSFORMERS_GREATER_EQUAL_4_4 and user_tokenizer is None: + raise ModuleNotFoundError( + "`DepthScore` metric with default tokenizers requires `transformers` package be installed." + " Either install with `pip install transformers>=4.4` or `pip install torchmetrics[text]`." + ) + + self.model_name_or_path = model_name_or_path or _DEFAULT_MODEL + self.num_layers = num_layers + self.all_layers = all_layers + self.model = model + self.user_forward_fn = user_forward_fn + self.verbose = verbose + self.embedding_device = device + self.max_length = max_length + self.batch_size = batch_size + self.num_threads = num_threads + self.n_alpha = n_alpha + self.eps = eps + self.p = p + self.depth_measure = depth_measure + self.truncation = truncation + self.multi_ref_reduction = multi_ref_reduction + + self.ref_group_boundaries: Optional[List[Tuple[int, int]]] = None + + if user_tokenizer: + self.tokenizer = user_tokenizer + self.user_tokenizer = True + else: + from transformers import AutoTokenizer + + if model_name_or_path is None: + rank_zero_warn( + "The argument `model_name_or_path` was not specified while it is required when the default" + f" `transformers` model is used. It will use the default recommended model - {_DEFAULT_MODEL!r}." + ) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path) + self.user_tokenizer = False + + self.add_state("preds_input_ids", [], dist_reduce_fx="cat") + self.add_state("preds_attention_mask", [], dist_reduce_fx="cat") + self.add_state("target_input_ids", [], dist_reduce_fx="cat") + self.add_state("target_attention_mask", [], dist_reduce_fx="cat") + + def update( + self, preds: Union[str, Sequence[str]], target: Union[str, Sequence[str], Sequence[Sequence[str]]] + ) -> None: + """Store predictions/references for computing DepthScore. + + It is necessary to store sentences in a tokenized form to ensure the DDP mode working. + + """ + if isinstance(preds, str): + preds = [preds] + if isinstance(target, str): + target = [target] + if not isinstance(preds, list): + preds = list(preds) + if not isinstance(target, list): + target = list(target) + + if len(preds) != len(target): + raise ValueError( + "Expected number of predicted and reference sentences to be the same, but got" + f"{len(preds)} and {len(target)}" + ) + + if isinstance(preds, list) and len(preds) > 0 and isinstance(target, list) and len(target) > 0: + preds, target, self.ref_group_boundaries = _preprocess_multiple_references(preds, target) + + preds_dict, _ = _preprocess_text( + preds, + self.tokenizer, + self.max_length, + truncation=self.truncation, + sort_according_length=False, + own_tokenizer=self.user_tokenizer, + ) + target_dict, _ = _preprocess_text( + cast(List[str], target), + self.tokenizer, + self.max_length, + truncation=self.truncation, + sort_according_length=False, + own_tokenizer=self.user_tokenizer, + ) + + self.preds_input_ids.append(preds_dict["input_ids"]) + self.preds_attention_mask.append(preds_dict["attention_mask"]) + self.target_input_ids.append(target_dict["input_ids"]) + self.target_attention_mask.append(target_dict["attention_mask"]) + + def compute(self) -> Tensor: + """Calculate DepthScore.""" + preds = { + "input_ids": dim_zero_cat(self.preds_input_ids), + "attention_mask": dim_zero_cat(self.preds_attention_mask), + } + target = { + "input_ids": dim_zero_cat(self.target_input_ids), + "attention_mask": dim_zero_cat(self.target_attention_mask), + } + + out = depth_score( + preds=preds, # supports dict input (tokenized) + target=target, + model_name_or_path=self.model_name_or_path, + num_layers=self.num_layers, + all_layers=self.all_layers, + n_alpha=self.n_alpha, + eps=self.eps, + p=self.p, + depth_measure=self.depth_measure, + device=self.embedding_device if self.embedding_device is not None else None, + model=self.model, + user_tokenizer=self.tokenizer if self.user_tokenizer else None, + user_forward_fn=self.user_forward_fn, + max_length=self.max_length, + batch_size=self.batch_size, + num_threads=self.num_threads, + truncation=self.truncation, + verbose=self.verbose, + multi_ref_reduction=self.multi_ref_reduction, + ) + + # out expected: {"depth_score": Tensor} aligned with flattened refs if multi-ref used + if self.ref_group_boundaries is not None: + out = _postprocess_multiple_references_distance( + out, + self.ref_group_boundaries, + reduction=self.multi_ref_reduction, + ) + + return out + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: A matplotlib axis object. If provided will add plot to that axis. + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single value + >>> from torchmetrics.text.depth_score import DepthScore + >>> preds = ["hello there", "general kenobi"] + >>> target = ["hello there", "master kenobi"] + >>> metric = DepthScore() + >>> metric.update(preds, target) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> from torch import tensor + >>> from torchmetrics.text.depth_score import DepthScore + >>> preds = ["hello there", "general kenobi"] + >>> target = ["hello there", "master kenobi"] + >>> metric = DepthScore() + >>> values = [] + >>> for _ in range(10): + ... val = metric(preds, target) + ... val = val.mean() # convert into a single scalar + ... values.append(val) + >>> fig_, ax_ = metric.plot(values) + + """ + if val is None: # default average score across sentences + val = self.compute() + val = val.mean() + return self._plot(val, ax) diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index 88da9076269..775aa57177e 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -29,8 +29,11 @@ _NLTK_AVAILABLE = RequirementCache("nltk") _ROUGE_SCORE_AVAILABLE = RequirementCache("rouge_score") _BERTSCORE_AVAILABLE = RequirementCache("bert_score") +_GEOMLOSS_AVAILABLE = RequirementCache("geomloss") +_POT_AVAILABLE = RequirementCache("POT") _SCIPY_AVAILABLE = RequirementCache("scipy") _SCIPY_GREATER_EQUAL_1_8 = RequirementCache("scipy>=1.8.0") +_SKLEARN_AVAILABLE = RequirementCache("scikit-learn") _TORCH_FIDELITY_AVAILABLE = RequirementCache("torch_fidelity") _LPIPS_AVAILABLE = RequirementCache("lpips") _PYCOCOTOOLS_AVAILABLE = RequirementCache("pycocotools") diff --git a/tests/unittests/text/test_depth_score.py b/tests/unittests/text/test_depth_score.py new file mode 100644 index 00000000000..8a5634c9e71 --- /dev/null +++ b/tests/unittests/text/test_depth_score.py @@ -0,0 +1,342 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Check if nlg_eval_via_simi_measures is available for reference metric tests +import importlib.util +import os +from collections.abc import Sequence +from functools import partial + +import pytest +import torch +from torch import Tensor + +from torchmetrics.functional.text.depth_score import depth_score +from torchmetrics.text.depth_score import DepthScore +from torchmetrics.utilities.imports import ( + _GEOMLOSS_AVAILABLE, + _POT_AVAILABLE, + _SKLEARN_AVAILABLE, + _TRANSFORMERS_GREATER_EQUAL_4_4, +) +from unittests._helpers import ( + _IS_WINDOWS, + _TORCH_LESS_THAN_2_1, + _TRANSFORMERS_GREATER_EQUAL_4_54, + _TRANSFORMERS_RANGE_GE_4_50_LT_4_54, + skip_on_connection_issues, +) +from unittests.text._helpers import TextTester +from unittests.text._inputs import ( + _inputs_multiple_references, + _inputs_single_reference, + _inputs_single_sentence_multiple_references, +) + +_NLG_EVAL_AVAILABLE = importlib.util.find_spec("nlg_eval_via_simi_measures") is not None + +MODEL_NAME = "albert-base-v2" + + +_DEPTH_MEASURES = [ + "irw", + pytest.param( + "ai_irw", + marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="test requires scikit-learn"), + ), + pytest.param( + "sliced", + marks=pytest.mark.skipif(not _POT_AVAILABLE, reason="test requires POT"), + ), + pytest.param( + "wasserstein", + marks=pytest.mark.skipif(not _POT_AVAILABLE, reason="test requires POT"), + ), + pytest.param( + "mmd", + marks=pytest.mark.skipif(not _GEOMLOSS_AVAILABLE, reason="test requires geomloss"), + ), +] + +# Disable tokenizers parallelism (forking not friendly with parallelism) +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@skip_on_connection_issues() +@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") +def _reference_depth_score( + preds: Sequence[str], + target: Sequence[str], + num_layers: int, + depth_measure: str = "irw", +) -> Tensor: + # Reference source code depthscore implementation + from nlg_eval_via_simi_measures.depth_score import DepthScoreMetric + + metric_call = DepthScoreMetric(MODEL_NAME, layers_to_consider=num_layers, considered_measure=depth_measure) + out = metric_call.evaluate_batch(list(target), list(preds)) + return torch.as_tensor(out["depth_score"], dtype=torch.float32) + + +@pytest.mark.parametrize("num_layers", [4, 8]) +@pytest.mark.parametrize("depth_measure", _DEPTH_MEASURES) +@pytest.mark.parametrize( + ("preds", "targets"), + [(_inputs_single_reference.preds, _inputs_single_reference.target)], +) +@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") +@pytest.mark.xfail( + RuntimeError, + condition=_TORCH_LESS_THAN_2_1 and _TRANSFORMERS_RANGE_GE_4_50_LT_4_54, + reason="could be due to torch compatibility issues with transformers", +) +@pytest.mark.xfail( + ImportError, + condition=_TORCH_LESS_THAN_2_1 and _IS_WINDOWS and _TRANSFORMERS_GREATER_EQUAL_4_54, + reason="another strange behaviour of transformers on windows", +) +class TestDepthScore(TextTester): + """Tests for DepthScore.""" + + @pytest.mark.skipif(not _NLG_EVAL_AVAILABLE, reason="test requires nlg_eval_via_simi_measures to be installed") + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + @skip_on_connection_issues() + def test_depthscore_class(self, ddp, preds, targets, num_layers, depth_measure): + """Test the depth score class.""" + metric_args = { + "model_name_or_path": MODEL_NAME, + "num_layers": num_layers, + "depth_measure": depth_measure, + "device": "cpu", + "batch_size": 8, + "max_length": 128, + "truncation": True, # nlg_eval reference always truncates + } + reference_depth_score_metric = partial( + _reference_depth_score, + num_layers=num_layers, + depth_measure=depth_measure, + ) + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + targets=targets, + metric_class=DepthScore, + reference_metric=reference_depth_score_metric, + metric_args=metric_args, + check_scriptable=False, # huggingface transformers are not usually scriptable + ignore_order=ddp, # ignore order of predictions when DDP is used + ) + + @pytest.mark.skipif(not _NLG_EVAL_AVAILABLE, reason="test requires nlg_eval_via_simi_measures to be installed") + @skip_on_connection_issues() + def test_depthscore_functional(self, preds, targets, num_layers, depth_measure): + """Test the depthscore functional.""" + metric_args = { + "model_name_or_path": MODEL_NAME, + "num_layers": num_layers, + "depth_measure": depth_measure, + "truncation": True, # nlg_eval reference always truncates + } + reference_depth_score_metric = partial( + _reference_depth_score, + num_layers=num_layers, + depth_measure=depth_measure, + ) + + self.run_functional_metric_test( + preds, + targets, + metric_functional=depth_score, + reference_metric=reference_depth_score_metric, + metric_args=metric_args, + ) + + @skip_on_connection_issues() + def test_depthscore_differentiability(self, preds, targets, num_layers, depth_measure): + """Test the depthscore differentiability.""" + metric_args = { + "model_name_or_path": MODEL_NAME, + "num_layers": num_layers, + "depth_measure": depth_measure, + "truncation": True, # nlg_eval reference always truncates + } + + self.run_differentiability_test( + preds=preds, + targets=targets, + metric_module=DepthScore, + metric_functional=depth_score, + metric_args=metric_args, + ) + + +@skip_on_connection_issues() +@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") +@pytest.mark.xfail( + RuntimeError, + condition=_TORCH_LESS_THAN_2_1 and _TRANSFORMERS_RANGE_GE_4_50_LT_4_54, + reason="could be due to torch compatibility issues with transformers", +) +@pytest.mark.xfail( + ImportError, + condition=_TORCH_LESS_THAN_2_1 and _IS_WINDOWS and _TRANSFORMERS_GREATER_EQUAL_4_54, + reason="another strange behaviour of transformers on windows", +) +def test_depthscore_sorting(): + """Test that DepthScore is invariant to the order of the inputs.""" + short = "Short text" + long = "This is a longer text" + + preds = [long, long] + targets = [long, short] + + metric = DepthScore(model_name_or_path=MODEL_NAME, num_layers=4, device="cpu", batch_size=2, max_length=64) + score = metric(preds, targets) + + # First index should be the self-comparison - sorting by length should not shuffle this. + # Distance metric: self-comparison should have a smaller distance than mismatched pair. + assert score[0] < score[1] + + +@skip_on_connection_issues() +@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") +@pytest.mark.xfail( + RuntimeError, + condition=_TORCH_LESS_THAN_2_1 and _TRANSFORMERS_RANGE_GE_4_50_LT_4_54, + reason="could be due to torch compatibility issues with transformers", +) +@pytest.mark.xfail( + ImportError, + condition=_TORCH_LESS_THAN_2_1 and _IS_WINDOWS and _TRANSFORMERS_GREATER_EQUAL_4_54, + reason="another strange behaviour of transformers on windows", +) +@pytest.mark.parametrize("truncation", [True, False]) +def test_depthscore_truncation(truncation: bool): + """Test that DepthScore truncation works as expected.""" + pred = ["abc " * 2000] + gt = ["def " * 2000] + metric = DepthScore( + model_name_or_path=MODEL_NAME, + num_layers=4, + device="cpu", + batch_size=1, + max_length=64, + truncation=truncation, + ) + + if truncation: + res = metric(pred, gt) + # Should produce a finite tensor (not error). Value itself is not bounded. + assert torch.isfinite(res).all() + else: + with pytest.raises(RuntimeError, match="The expanded size of the tensor.*must match.*"): + metric(pred, gt) + + +@skip_on_connection_issues() +@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") +@pytest.mark.xfail( + RuntimeError, + condition=_TORCH_LESS_THAN_2_1 and _TRANSFORMERS_RANGE_GE_4_50_LT_4_54, + reason="could be due to torch compatibility issues with transformers", +) +@pytest.mark.xfail( + ImportError, + condition=_TORCH_LESS_THAN_2_1 and _IS_WINDOWS and _TRANSFORMERS_GREATER_EQUAL_4_54, + reason="another strange behaviour of transformers on windows", +) +def test_depthscore_single_str_input(): + """Test if DepthScore works with single string preds and target.""" + preds = "hello there" + target = "hello there" + + metric = DepthScore(model_name_or_path=MODEL_NAME, num_layers=4, device="cpu", batch_size=1, max_length=64) + score_class = metric(preds, target) + + # Distance for identical text should be smaller than for different text. + score_class_ident = score_class.item() + + score_functional = depth_score( + preds, + target, + model_name_or_path=MODEL_NAME, + num_layers=4, + device="cpu", + batch_size=1, + max_length=64, + ) + score_func_ident = score_functional.item() + + assert score_class_ident == pytest.approx(score_func_ident, abs=1e-6) + + # Compare to a different target to assert "identical is better" + score_diff = metric("hello there", "general kenobi").item() + assert score_class_ident <= score_diff + + +@pytest.mark.parametrize( + ("preds", "target"), + [ + ( + _inputs_single_sentence_multiple_references.preds, + _inputs_single_sentence_multiple_references.target, + ), + ( + ["hello there", "I'm in the middle", "general kenobi"], + (["hello there", "master kenobi"], "I'm here", ("hello there", "master kenobi")), + ), + ], +) +@skip_on_connection_issues() +@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") +@pytest.mark.xfail( + RuntimeError, + condition=_TORCH_LESS_THAN_2_1 and _TRANSFORMERS_RANGE_GE_4_50_LT_4_54, + reason="could be due to torch compatibility issues with transformers", +) +@pytest.mark.xfail( + ImportError, + condition=_TORCH_LESS_THAN_2_1 and _IS_WINDOWS and _TRANSFORMERS_GREATER_EQUAL_4_54, + reason="another strange behaviour of transformers on windows", +) +def test_depthscore_multiple_references(preds, target): + """Test both functional and class APIs with multiple references.""" + # Functional returns a 1D tensor; class returns dict with "depth_score" + result_func = depth_score(preds, target) + metric = DepthScore() + result_class = metric(preds, target) + + # They should match exactly (same code path), and output should be per-pred after reduction (min across refs). + assert torch.allclose(result_func, result_class, atol=1e-6) + + # Sanity: output length equals number of predictions (not flattened refs) + if isinstance(preds, str): + assert result_func.numel() == 1 + else: + assert result_func.numel() == len(preds) + + +@pytest.mark.skipif(not _TRANSFORMERS_GREATER_EQUAL_4_4, reason="test requires transformers>4.4") +def test_depthscore_invalid_references(): + """Test both functional and class APIs with invalid references.""" + preds = _inputs_multiple_references.preds + target = _inputs_multiple_references.target + + with pytest.raises(ValueError, match="Invalid input provided."): + depth_score(preds, target) + + metric = DepthScore() + with pytest.raises(ValueError, match="Invalid input provided."): + metric(preds, target)