From ded8cecdfac0e41c5192b8a0645ba5e8fb69103d Mon Sep 17 00:00:00 2001 From: Luca Butera <22855332+LucaButera@users.noreply.github.com> Date: Mon, 24 Apr 2023 12:25:29 +0200 Subject: [PATCH 1/5] Added t_dim argument to Masked Metric to specify which is the time dimension --- tsl/metrics/torch/metric_base.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tsl/metrics/torch/metric_base.py b/tsl/metrics/torch/metric_base.py index f187f03..4f72230 100644 --- a/tsl/metrics/torch/metric_base.py +++ b/tsl/metrics/torch/metric_base.py @@ -38,12 +38,14 @@ class MaskedMetric(Metric): sequences by accepting a boolean mask as additional input. Args: - metric_fn: Base function to compute the metric point wise. + metric_fn: Base function to compute the metric point-wise. mask_nans (bool, optional): Whether to automatically mask nan values. mask_inf (bool, optional): Whether to automatically mask infinite values. at (int, optional): Whether to compute the metric only w.r.t. a certain time step. + t_dim (int): The index of the dimension that represents time in a batch. + Default assumes [b t n f] format, hence is 1. 
""" is_differentiable: bool = None @@ -57,6 +59,7 @@ def __init__(self, metric_fn_kwargs=None, at=None, full_state_update: bool = None, + t_dim: int = 1, **kwargs: Any): # set 'full_state_update' before Metric instantiation if full_state_update is not None: @@ -74,6 +77,7 @@ def __init__(self, self.at = slice(None) else: self.at = slice(at, at + 1) + self.t_dim = t_dim self.add_state('value', dist_reduce_fx='sum', default=torch.tensor(0., dtype=torch.float)) @@ -109,10 +113,10 @@ def is_masked(self, mask): return self.mask_inf or self.mask_nans or (mask is not None) def update(self, y_hat, y, mask=None): - y_hat = y_hat[:, self.at] - y = y[:, self.at] + y_hat = y_hat.select(self.t_dim, self.at) + y = y.select(self.t_dim, self.at) if mask is not None: - mask = mask[:, self.at] + mask = mask.select(self.t_dim, self.at) if self.is_masked(mask): val, numel = self._compute_masked(y_hat, y, mask) else: From 8727d10f241c86ccdcc6df691a111939687f571d Mon Sep 17 00:00:00 2001 From: Luca Butera <22855332+LucaButera@users.noreply.github.com> Date: Mon, 24 Apr 2023 12:58:16 +0200 Subject: [PATCH 2/5] Renamed MaskedMetric's t_dim parameter to dim --- tsl/metrics/torch/metric_base.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tsl/metrics/torch/metric_base.py b/tsl/metrics/torch/metric_base.py index 4f72230..6858f6e 100644 --- a/tsl/metrics/torch/metric_base.py +++ b/tsl/metrics/torch/metric_base.py @@ -44,7 +44,8 @@ class MaskedMetric(Metric): values. at (int, optional): Whether to compute the metric only w.r.t. a certain time step. - t_dim (int): The index of the dimension that represents time in a batch. + dim (int): The index of the dimension that represents time in a batch. + Relevant only when also 'at' is defined. Default assumes [b t n f] format, hence is 1. 
""" @@ -59,7 +60,7 @@ def __init__(self, metric_fn_kwargs=None, at=None, full_state_update: bool = None, - t_dim: int = 1, + dim: int = 1, **kwargs: Any): # set 'full_state_update' before Metric instantiation if full_state_update is not None: @@ -77,7 +78,7 @@ def __init__(self, self.at = slice(None) else: self.at = slice(at, at + 1) - self.t_dim = t_dim + self.dim = dim self.add_state('value', dist_reduce_fx='sum', default=torch.tensor(0., dtype=torch.float)) @@ -113,10 +114,10 @@ def is_masked(self, mask): return self.mask_inf or self.mask_nans or (mask is not None) def update(self, y_hat, y, mask=None): - y_hat = y_hat.select(self.t_dim, self.at) - y = y.select(self.t_dim, self.at) + y_hat = y_hat.select(self.dim, self.at) + y = y.select(self.dim, self.at) if mask is not None: - mask = mask.select(self.t_dim, self.at) + mask = mask.select(self.dim, self.at) if self.is_masked(mask): val, numel = self._compute_masked(y_hat, y, mask) else: From d9ad5365e554f854945b0a5088d9f311115180a1 Mon Sep 17 00:00:00 2001 From: Luca Butera <22855332+LucaButera@users.noreply.github.com> Date: Mon, 24 Apr 2023 16:29:43 +0200 Subject: [PATCH 3/5] Fixed MaskedMetric to avoid using slices which do not work with torch.select --- tsl/metrics/torch/metric_base.py | 22 ++++++++++----------- tsl/metrics/torch/metrics.py | 32 +++++++++++++++++++++++++------ tsl/metrics/torch/pinball_loss.py | 9 +++++++-- 3 files changed, 43 insertions(+), 20 deletions(-) diff --git a/tsl/metrics/torch/metric_base.py b/tsl/metrics/torch/metric_base.py index 6858f6e..9a014df 100644 --- a/tsl/metrics/torch/metric_base.py +++ b/tsl/metrics/torch/metric_base.py @@ -1,7 +1,7 @@ import inspect from copy import deepcopy from functools import partial -from typing import Any +from typing import Any, Optional import torch from torchmetrics import Metric @@ -55,10 +55,10 @@ class MaskedMetric(Metric): def __init__(self, metric_fn, - mask_nans=False, - mask_inf=False, + mask_nans: Optional[bool] = False, + mask_inf: 
Optional[bool] = False, metric_fn_kwargs=None, - at=None, + at: Optional[int] = None, full_state_update: bool = None, dim: int = 1, **kwargs: Any): @@ -74,10 +74,7 @@ def __init__(self, self.mask_nans = mask_nans self.mask_inf = mask_inf - if at is None: - self.at = slice(None) - else: - self.at = slice(at, at + 1) + self.at = at self.dim = dim self.add_state('value', dist_reduce_fx='sum', @@ -114,10 +111,11 @@ def is_masked(self, mask): return self.mask_inf or self.mask_nans or (mask is not None) def update(self, y_hat, y, mask=None): - y_hat = y_hat.select(self.dim, self.at) - y = y.select(self.dim, self.at) - if mask is not None: - mask = mask.select(self.dim, self.at) + if self.at is not None: + y_hat = y_hat.select(self.dim, self.at) + y = y.select(self.dim, self.at) + if mask is not None: + mask = mask.select(self.dim, self.at) if self.is_masked(mask): val, numel = self._compute_masked(y_hat, y, mask) else: diff --git a/tsl/metrics/torch/metrics.py b/tsl/metrics/torch/metrics.py index 5018db2..15edb3c 100644 --- a/tsl/metrics/torch/metrics.py +++ b/tsl/metrics/torch/metrics.py @@ -18,7 +18,10 @@ class MaskedMAE(MaskedMetric): mask_inf (bool, optional): Whether to automatically mask infinite values. at (int, optional): Whether to compute the metric only w.r.t. a certain - time step. + time step. + dim (int): The index of the dimension that represents time in a batch. + Relevant only when also 'at' is defined. + Default assumes [b t n f] format, hence is 1. """ is_differentiable: bool = True @@ -29,12 +32,14 @@ def __init__(self, mask_nans=False, mask_inf=False, at=None, + dim: int = 1, **kwargs: Any): super(MaskedMAE, self).__init__(metric_fn=F.l1_loss, mask_nans=mask_nans, mask_inf=mask_inf, metric_fn_kwargs={'reduction': 'none'}, at=at, + dim=dim, **kwargs) @@ -45,19 +50,23 @@ class MaskedMAPE(MaskedMetric): mask_nans (bool, optional): Whether to automatically mask nan values. at (int, optional): Whether to compute the metric only w.r.t. a certain time step. 
+ dim (int): The index of the dimension that represents time in a batch. + Relevant only when also 'at' is defined. + Default assumes [b t n f] format, hence is 1. """ is_differentiable: bool = True higher_is_better: bool = False full_state_update: bool = False - def __init__(self, mask_nans=False, at=None, **kwargs: Any): + def __init__(self, mask_nans=False, at=None, dim: int = 1, **kwargs: Any): super(MaskedMAPE, self).__init__(metric_fn=mape, mask_nans=mask_nans, mask_inf=True, metric_fn_kwargs={'reduction': 'none'}, at=at, + dim=dim, **kwargs) @@ -70,6 +79,9 @@ class MaskedMSE(MaskedMetric): values. at (int, optional): Whether to compute the metric only w.r.t. a certain time step. + dim (int): The index of the dimension that represents time in a batch. + Relevant only when also 'at' is defined. + Default assumes [b t n f] format, hence is 1. """ is_differentiable: bool = True @@ -80,12 +92,14 @@ def __init__(self, mask_nans=False, mask_inf=False, at=None, + dim: int = 1, **kwargs: Any): super(MaskedMSE, self).__init__(metric_fn=F.mse_loss, mask_nans=mask_nans, mask_inf=mask_inf, metric_fn_kwargs={'reduction': 'none'}, at=at, + dim=dim, **kwargs) @@ -98,6 +112,9 @@ class MaskedMRE(MaskedMetric): values. at (int, optional): Whether to compute the metric only w.r.t. a certain time step. + dim (int): The index of the dimension that represents time in a batch. + Relevant only when also 'at' is defined. + Default assumes [b t n f] format, hence is 1. 
""" is_differentiable: bool = True @@ -108,12 +125,14 @@ def __init__(self, mask_nans=False, mask_inf=False, at=None, + dim: int = 1, **kwargs: Any): super(MaskedMRE, self).__init__(metric_fn=F.l1_loss, mask_nans=mask_nans, mask_inf=mask_inf, metric_fn_kwargs={'reduction': 'none'}, at=at, + dim=dim, **kwargs) self.add_state('tot', dist_reduce_fx='sum', @@ -138,10 +157,11 @@ def compute(self): return self.value def update(self, y_hat, y, mask=None): - y_hat = y_hat[:, self.at] - y = y[:, self.at] - if mask is not None: - mask = mask[:, self.at] + if self.at is not None: + y_hat = y_hat.select(self.dim, self.at) + y = y.select(self.dim, self.at) + if mask is not None: + mask = mask.select(self.dim, self.at) if self.is_masked(mask): val, numel, tot = self._compute_masked(y_hat, y, mask) else: diff --git a/tsl/metrics/torch/pinball_loss.py b/tsl/metrics/torch/pinball_loss.py index b530ff3..1fb79db 100644 --- a/tsl/metrics/torch/pinball_loss.py +++ b/tsl/metrics/torch/pinball_loss.py @@ -17,6 +17,9 @@ class MaskedPinballLoss(MaskedMetric): mini-batches. at (int, optional): Whether to compute the metric only w.r.t. a certain time step. + dim (int): The index of the dimension that represents time in a batch. + Relevant only when also 'at' is defined. + Default assumes [b t n f] format, hence is 1. 
""" is_differentiable: bool = True @@ -31,7 +34,8 @@ def __init__(self, dist_sync_on_step=False, process_group=None, dist_sync_fn=None, - at=None): + at=None, + dim: int = 1): super(MaskedPinballLoss, self).__init__(metric_fn=pinball_loss, mask_nans=mask_nans, @@ -41,4 +45,5 @@ def __init__(self, process_group=process_group, dist_sync_fn=dist_sync_fn, metric_fn_kwargs={'q': q}, - at=at) + at=at, + dim=dim) From 25e6a84499b646827992b5248397c71fd0d221ac Mon Sep 17 00:00:00 2001 From: Luca Butera <22855332+LucaButera@users.noreply.github.com> Date: Wed, 26 Apr 2023 18:06:51 +0200 Subject: [PATCH 4/5] Improved docs and hints --- tsl/metrics/torch/metric_base.py | 31 ++++++++++----- tsl/metrics/torch/metrics.py | 63 ++++++++++++++++++++----------- tsl/metrics/torch/pinball_loss.py | 31 +++++++++------ 3 files changed, 81 insertions(+), 44 deletions(-) diff --git a/tsl/metrics/torch/metric_base.py b/tsl/metrics/torch/metric_base.py index 9a014df..b58fe32 100644 --- a/tsl/metrics/torch/metric_base.py +++ b/tsl/metrics/torch/metric_base.py @@ -1,7 +1,7 @@ import inspect from copy import deepcopy from functools import partial -from typing import Any, Optional +from typing import Any, Callable, Dict, Optional import torch from torchmetrics import Metric @@ -38,15 +38,26 @@ class MaskedMetric(Metric): sequences by accepting a boolean mask as additional input. Args: - metric_fn: Base function to compute the metric point-wise. - mask_nans (bool, optional): Whether to automatically mask nan values. - mask_inf (bool, optional): Whether to automatically mask infinite + metric_fn (callable): Base function to compute the metric point-wise. + mask_nans (bool): Whether to automatically mask nan values. + (default: :obj:`False`) + mask_inf (bool): Whether to automatically mask infinite values. + (default: :obj:`False`) + metric_fn_kwargs (dict, optional): Keyword arguments needed by + :obj:`metric_fn`. 
+ (default: :obj:`None`) at (int, optional): Whether to compute the metric only w.r.t. a certain time step. + (default: :obj:`None`) + full_state_update (bool, optional): Set this to overwrite the + :obj:`full_state_update` value of the + :obj:`torchmetrics.Metric` base class. + (default: :obj:`None`) dim (int): The index of the dimension that represents time in a batch. Relevant only when also 'at' is defined. - Default assumes [b t n f] format, hence is 1. + Default assumes [b t n f] format. + (default: :obj:`1`) """ is_differentiable: bool = None @@ -54,12 +65,12 @@ class MaskedMetric(Metric): full_state_update: bool = None def __init__(self, - metric_fn, - mask_nans: Optional[bool] = False, - mask_inf: Optional[bool] = False, - metric_fn_kwargs=None, + metric_fn: Callable, + mask_nans: bool = False, + mask_inf: bool = False, + metric_fn_kwargs: Optional[Dict[str, Any]] = None, at: Optional[int] = None, - full_state_update: bool = None, + full_state_update: Optional[bool] = None, dim: int = 1, **kwargs: Any): # set 'full_state_update' before Metric instantiation diff --git a/tsl/metrics/torch/metrics.py b/tsl/metrics/torch/metrics.py index 15edb3c..4716a8a 100644 --- a/tsl/metrics/torch/metrics.py +++ b/tsl/metrics/torch/metrics.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Optional import torch from torch.nn import functional as F @@ -14,14 +14,18 @@ class MaskedMAE(MaskedMetric): """Mean Absolute Error Metric. Args: - mask_nans (bool, optional): Whether to automatically mask nan values. - mask_inf (bool, optional): Whether to automatically mask infinite + mask_nans (bool): Whether to automatically mask nan values. + (default: :obj:`False`) + mask_inf (bool): Whether to automatically mask infinite values. + (default: :obj:`False`) at (int, optional): Whether to compute the metric only w.r.t. a certain time step. + (default: :obj:`None`) dim (int): The index of the dimension that represents time in a batch. 
Relevant only when also 'at' is defined. - Default assumes [b t n f] format, hence is 1. + Default assumes [b t n f] format. + (default: :obj:`1`) """ is_differentiable: bool = True @@ -29,9 +33,9 @@ class MaskedMAE(MaskedMetric): full_state_update: bool = False def __init__(self, - mask_nans=False, - mask_inf=False, - at=None, + mask_nans: bool = False, + mask_inf: bool = False, + at: Optional[int] = None, dim: int = 1, **kwargs: Any): super(MaskedMAE, self).__init__(metric_fn=F.l1_loss, @@ -47,19 +51,26 @@ class MaskedMAPE(MaskedMetric): """Mean Absolute Percentage Error Metric. Args: - mask_nans (bool, optional): Whether to automatically mask nan values. + mask_nans (bool): Whether to automatically mask nan values. + (default: :obj:`False`) at (int, optional): Whether to compute the metric only w.r.t. a certain time step. + (default: :obj:`None`) dim (int): The index of the dimension that represents time in a batch. Relevant only when also 'at' is defined. - Default assumes [b t n f] format, hence is 1. + Default assumes [b t n f] format. + (default: :obj:`1`) """ is_differentiable: bool = True higher_is_better: bool = False full_state_update: bool = False - def __init__(self, mask_nans=False, at=None, dim: int = 1, **kwargs: Any): + def __init__(self, + mask_nans: bool = False, + at: Optional[int] = None, + dim: int = 1, + **kwargs: Any): super(MaskedMAPE, self).__init__(metric_fn=mape, mask_nans=mask_nans, @@ -74,14 +85,18 @@ class MaskedMSE(MaskedMetric): """Mean Squared Error Metric. Args: - mask_nans (bool, optional): Whether to automatically mask nan values. - mask_inf (bool, optional): Whether to automatically mask infinite + mask_nans (bool): Whether to automatically mask nan values. + (default: :obj:`False`) + mask_inf (bool): Whether to automatically mask infinite values. + (default: :obj:`False`) at (int, optional): Whether to compute the metric only w.r.t. a certain time step. 
+ (default: :obj:`None`) dim (int): The index of the dimension that represents time in a batch. Relevant only when also 'at' is defined. - Default assumes [b t n f] format, hence is 1. + Default assumes [b t n f] format. + (default: :obj:`1`) """ is_differentiable: bool = True @@ -89,9 +104,9 @@ class MaskedMSE(MaskedMetric): full_state_update: bool = False def __init__(self, - mask_nans=False, - mask_inf=False, - at=None, + mask_nans: bool = False, + mask_inf: bool = False, + at: Optional[int] = None, dim: int = 1, **kwargs: Any): super(MaskedMSE, self).__init__(metric_fn=F.mse_loss, @@ -107,14 +122,18 @@ class MaskedMRE(MaskedMetric): """Mean Relative Error Metric. Args: - mask_nans (bool, optional): Whether to automatically mask nan values. - mask_inf (bool, optional): Whether to automatically mask infinite + mask_nans (bool): Whether to automatically mask nan values. + (default: :obj:`False`) + mask_inf (bool): Whether to automatically mask infinite values. + (default: :obj:`False`) at (int, optional): Whether to compute the metric only w.r.t. a certain time step. + (default: :obj:`None`) dim (int): The index of the dimension that represents time in a batch. Relevant only when also 'at' is defined. - Default assumes [b t n f] format, hence is 1. + Default assumes [b t n f] format. 
+ (default: :obj:`1`) """ is_differentiable: bool = True @@ -122,9 +141,9 @@ class MaskedMRE(MaskedMetric): full_state_update: bool = False def __init__(self, - mask_nans=False, - mask_inf=False, - at=None, + mask_nans: bool = False, + mask_inf: bool = False, + at: Optional[int] = None, dim: int = 1, **kwargs: Any): super(MaskedMRE, self).__init__(metric_fn=F.l1_loss, diff --git a/tsl/metrics/torch/pinball_loss.py b/tsl/metrics/torch/pinball_loss.py index 1fb79db..45bd45d 100644 --- a/tsl/metrics/torch/pinball_loss.py +++ b/tsl/metrics/torch/pinball_loss.py @@ -1,3 +1,5 @@ +from typing import Any, Callable, Optional + from tsl.metrics.torch import pinball_loss from tsl.metrics.torch.metric_base import MaskedMetric @@ -7,19 +9,24 @@ class MaskedPinballLoss(MaskedMetric): Args: q (float): Target quantile. - mask_nans (bool, optional): Whether to automatically mask nan values. - mask_inf (bool, optional): Whether to automatically mask infinite + mask_nans (bool): Whether to automatically mask nan values. + (default: :obj:`False`) + mask_inf (bool): Whether to automatically mask infinite values. - compute_on_step (bool, optional): Whether to compute the metric + (default: :obj:`False`) + compute_on_step (bool): Whether to compute the metric right-away or if accumulate the results. This should be :obj:`True` when using the metric to compute a loss function, :obj:`False` if the metric is used for logging the aggregate error across different mini-batches. + (default: :obj:`True`) at (int, optional): Whether to compute the metric only w.r.t. a certain time step. + (default: :obj:`None`) dim (int): The index of the dimension that represents time in a batch. Relevant only when also 'at' is defined. - Default assumes [b t n f] format, hence is 1. + Default assumes [b t n f] format. 
+ (default: :obj:`1`) """ is_differentiable: bool = True @@ -27,14 +34,14 @@ class MaskedPinballLoss(MaskedMetric): full_state_update: bool = False def __init__(self, - q, - mask_nans=False, - mask_inf=False, - compute_on_step=True, - dist_sync_on_step=False, - process_group=None, - dist_sync_fn=None, - at=None, + q: float, + mask_nans: bool = False, + mask_inf: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Any = None, + dist_sync_fn: Callable = None, + at: Optional[int] = None, dim: int = 1): super(MaskedPinballLoss, self).__init__(metric_fn=pinball_loss, From e7a0c2a3d967269f7f1ddd4febe6073e1e23be39 Mon Sep 17 00:00:00 2001 From: Luca Butera <22855332+LucaButera@users.noreply.github.com> Date: Thu, 22 Jun 2023 14:10:31 +0200 Subject: [PATCH 5/5] Implemented enhanced masked metric with slicing and multiple metric support --- tsl/metrics/torch/metric_base.py | 161 +++++++++++++++++++------------ tsl/typing.py | 4 +- tsl/utils/python_utils.py | 40 ++++++++ 3 files changed, 142 insertions(+), 63 deletions(-) diff --git a/tsl/metrics/torch/metric_base.py b/tsl/metrics/torch/metric_base.py index b58fe32..fe25de4 100644 --- a/tsl/metrics/torch/metric_base.py +++ b/tsl/metrics/torch/metric_base.py @@ -1,12 +1,15 @@ import inspect from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import torch from torchmetrics import Metric from torchmetrics.utilities.checks import _check_same_shape +from tsl.typing import Slicer +from tsl.utils.python_utils import parse_slicing_string + def convert_to_masked_metric(metric_fn, **kwargs): """ @@ -36,69 +39,115 @@ class MaskedMetric(Metric): In particular a `MaskedMetric` accounts for missing values in the input sequences by accepting a boolean mask as additional input. + Multiple metric functions can be specified, + in which case they will be averaged. 
+ Weights can be assigned to perform a + weighted average of the different metrics. Args: - metric_fn (callable): Base function to compute the metric point-wise. + metric_fn (Sequence[callable], callable): + Base function to compute the metric + point-wise, multiple functions can be passed as a sequence. mask_nans (bool): Whether to automatically mask nan values. (default: :obj:`False`) mask_inf (bool): Whether to automatically mask infinite values. (default: :obj:`False`) - metric_fn_kwargs (dict, optional): Keyword arguments needed by - :obj:`metric_fn`. + metric_fn_kwargs (Sequence[dict], dict, optional): + Keyword arguments needed by :obj:`metric_fn`. + Use a sequence of keyword arguments if different :obj:`metric_fn` + require different arguments. + (default: :obj:`None`) + weights (Sequence[float], float, optional): + Weight assigned to each :obj:`metric_fn`. + Use a sequence if different :obj:`metric_fn` + require different weights. (default: :obj:`None`) - at (int, optional): Whether to compute the metric only w.r.t. a certain - time step. + at (str, Sequence[Tuple[Slicer, ...] | str], tuple[Slicer, ...], + Slicer, optional): + Numpy style slicing to define specific parts + of the output to compute the metrics on. + Either one for all metrics or a sequence for each metric. + Slicing can either be a proper slicing tuple + or a string representation containing just + the part you would put inside square brackets + to index an array/tensor. (default: :obj:`...`) full_state_update (bool, optional): Set this to overwrite the :obj:`full_state_update` value of the :obj:`torchmetrics.Metric` base class. (default: :obj:`None`) - dim (int): The index of the dimension that represents time in a batch. - Relevant only when also 'at' is defined. - Default assumes [b t n f] format.
- (default: :obj:`1`) """ is_differentiable: bool = None higher_is_better: bool = None full_state_update: bool = None - def __init__(self, - metric_fn: Callable, - mask_nans: bool = False, - mask_inf: bool = False, - metric_fn_kwargs: Optional[Dict[str, Any]] = None, - at: Optional[int] = None, - full_state_update: Optional[bool] = None, - dim: int = 1, - **kwargs: Any): - # set 'full_state_update' before Metric instantiation - if full_state_update is not None: - self.__dict__['full_state_update'] = full_state_update - super(MaskedMetric, self).__init__(**kwargs) - + def __init__( + self, + metric_fn: Union[Sequence[Callable], Callable], + metric_fn_kwargs: Optional[Union[Sequence[Dict[str, Any]], + Dict[str, Any]]] = None, + mask_nans: bool = False, + mask_inf: bool = False, + at: Union[str, Sequence[Union[Tuple[Slicer, ...], str]], + tuple[Slicer, ...], Slicer] = ..., + weights: Optional[Sequence[float]] = None, + full_state_update: Optional[bool] = None, + **kwargs: Any, + ): + super().__init__( + metric_fn=None, + mask_nans=mask_nans, + mask_inf=mask_inf, + metric_fn_kwargs=None, + at=None, + full_state_update=full_state_update, + **kwargs, + ) + assert ( + len({ + len(e) + for e in (metric_fn, metric_fn_kwargs, at, weights) + if isinstance(e, Sequence) + }) == 1 + ), "All sequences used as masked metric arguments " \ + "must have the same length." 
if metric_fn_kwargs is None: - metric_fn_kwargs = dict() - - self.metric_fn = partial(metric_fn, **metric_fn_kwargs) - + metric_fn_kwargs = {} + if isinstance(metric_fn, Sequence) and isinstance( + metric_fn_kwargs, Sequence): + self.metric_fn = tuple( + partial(fn, **fn_kwargs) + for fn, fn_kwargs in zip(metric_fn, metric_fn_kwargs)) + elif isinstance(metric_fn, Sequence): + self.metric_fn = tuple( + partial(fn, **metric_fn_kwargs) for fn in metric_fn) + else: + self.metric_fn = (partial(metric_fn, **metric_fn_kwargs), ) + if isinstance(at, str) or not isinstance(at, Sequence): + at = (at, ) + at = list( + parse_slicing_string(e) if isinstance(e, str) else e for e in at) + self.at = at * len(self.metric_fn) if len(at) == 1 else at + if weights is None: + self.weights = (1.0, ) * len(self.metric_fn) + else: + self.weights = weights self.mask_nans = mask_nans self.mask_inf = mask_inf - self.at = at - self.dim = dim - self.add_state('value', - dist_reduce_fx='sum', - default=torch.tensor(0., dtype=torch.float)) - self.add_state('numel', - dist_reduce_fx='sum', - default=torch.tensor(0., dtype=torch.float)) - - def _check_mask(self, mask, val): + self.add_state("value", + dist_reduce_fx="sum", + default=torch.tensor(0.0, dtype=torch.float)) + self.add_state("numel", + dist_reduce_fx="sum", + default=torch.tensor(0.0, dtype=torch.float)) + + def _check_mask(self, mask, val, at=...): if mask is None: mask = torch.ones_like(val, dtype=torch.bool) else: - mask = mask.bool() + mask = mask[at].bool() _check_same_shape(mask, val) if self.mask_nans: mask = mask & ~torch.isnan(val) @@ -106,33 +155,21 @@ def _check_mask(self, mask, val): mask = mask & ~torch.isinf(val) return mask - def _compute_masked(self, y_hat, y, mask): - _check_same_shape(y_hat, y) - val = self.metric_fn(y_hat, y) - mask = self._check_mask(mask, val) - val = torch.where(mask, val, torch.zeros_like(val)) - return val.sum(), mask.sum() - - def _compute_std(self, y_hat, y): - _check_same_shape(y_hat, y) - 
val = self.metric_fn(y_hat, y) - return val.sum(), val.numel() - def is_masked(self, mask): return self.mask_inf or self.mask_nans or (mask is not None) def update(self, y_hat, y, mask=None): - if self.at is not None: - y_hat = y_hat.select(self.dim, self.at) - y = y.select(self.dim, self.at) - if mask is not None: - mask = mask.select(self.dim, self.at) - if self.is_masked(mask): - val, numel = self._compute_masked(y_hat, y, mask) - else: - val, numel = self._compute_std(y_hat, y) - self.value += val - self.numel += numel + _check_same_shape(y_hat, y) + for i in range(len(self.metric_fn)): + val = self.metric_fn[i](y_hat[self.at[i]], y[self.at[i]]) + if self.is_masked(mask): + mask = self._check_mask(mask, val, self.at[i]) + val[~mask] = 0 + numel = mask.sum() + else: + numel = val.numel() + self.value += val.sum() * self.weights[i] + self.numel += numel def compute(self): if self.numel > 0: diff --git a/tsl/typing.py b/tsl/typing.py index 2a41a12..1e49f65 100644 --- a/tsl/typing.py +++ b/tsl/typing.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Literal, Optional, Tuple, Type, Union +from typing import Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union from numpy import ndarray from pandas import DataFrame, DatetimeIndex, PeriodIndex, TimedeltaIndex @@ -33,3 +33,5 @@ "linear"]] ModelReturnOptions = Type[Union[Tensor, Dict, List, Tuple]] + +Slicer = TypeVar("Slicer", slice, type(Ellipsis), int) diff --git a/tsl/utils/python_utils.py b/tsl/utils/python_utils.py index c8382cd..b1224c7 100644 --- a/tsl/utils/python_utils.py +++ b/tsl/utils/python_utils.py @@ -1,9 +1,12 @@ import inspect import os +import re from argparse import ArgumentParser from typing import (Any, Callable, List, Mapping, Optional, Sequence, Set, Type, Union) +import numpy as np + def ensure_list(value: Any) -> List: # if isinstance(value, Sequence) and not isinstance(value, str): @@ -129,3 +132,40 @@ def filter_kwargs(target: Union[Callable, Type], kwargs: Mapping): for k, v in 
kwargs.items() if k in signature['signature'] } return kwargs + + +def parse_slicing_element(e: str) -> type(Ellipsis) | slice | list[Any] | int: + """ + Parses single slicing elements. + + Args: + e: string representing the slicing element. + + Returns: + The parsed element. + """ + if e == "...": + return Ellipsis + elif ":" in e: + return slice(*(int(i) if not i == "" else None for i in e.split(":"))) + elif e.startswith("[") and e.endswith("]"): + return list(int(i) for i in e[1:-1].split(",")) + else: + return int(e) + + +def parse_slicing_string(s: str) -> tuple[int | slice | type(Ellipsis)]: + """ + Parses slicing elements obtained by splitting a string at each comma + considering elements inside square brackets as individual elements. + + Args: + s: string to parse. + + Returns: + A tuple containing the parsed elements. + """ + return np.index_exp[( + parse_slicing_element(e) + for e in re.split(r'\s*,\s*(?![^\[\]]*])', s.replace(" ", "")) + if not e == "")]