diff --git a/pina/_src/core/label_tensor.py b/pina/_src/core/label_tensor.py index 41bccc6fc..9a81fb3a6 100644 --- a/pina/_src/core/label_tensor.py +++ b/pina/_src/core/label_tensor.py @@ -2,7 +2,6 @@ from copy import copy, deepcopy import torch -from torch import Tensor class LabelTensor(torch.Tensor): @@ -25,7 +24,6 @@ def __new__(cls, x, labels, *args, **kwargs): class. :rtype: LabelTensor """ - if isinstance(x, LabelTensor): return x return super().__new__(cls, x, *args, **kwargs) @@ -39,8 +37,7 @@ def tensor(self): :return: Tensor part of the :class:`~pina.label_tensor.LabelTensor`. :rtype: torch.Tensor """ - - return self.as_subclass(Tensor) + return self.as_subclass(torch.Tensor) def __init__(self, x, labels): """ @@ -366,25 +363,44 @@ def cat(tensors, dim=0): return cat_tensor @staticmethod - def stack(tensors): + def stack(tensors, dim=0): """ Stacks a list of tensors along a new dimension. For more details, see :meth:`torch.stack`. - :param list[LabelTensor] tensors: A list of tensors to stack. - All tensors must have the same shape. + :param list[LabelTensor] tensors: The list of tensors to stack. All + tensors must have the same shape and labels. + :param int dim: The dimension along which to insert the new stacked + dimension. It follows torch.stack semantics. The new dimension + cannot be inserted in a dimension that already has labels. :return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained by stacking the input tensors. :rtype: LabelTensor """ + # Ensure all tensors are LabelTensor instances + if not all(isinstance(tensor, LabelTensor) for tensor in tensors): + raise TypeError("All tensors must be LabelTensor instances.") - # Perform stacking in torch - new_tensor = torch.stack(tensors) + # Ensure all tensors have the same labels + reference_labels = tensors[0].labels + for tensor in tensors: + if tensor.labels != reference_labels: + raise ValueError("All tensors must have the same labels.") - # Increase labels keys by 1 - labels = tensors[0]._labels - labels = {key + 1: value for key, value in labels.items()} + # Avoid stacking along the labels dimension + if dim in tensors[0]._labels: + raise ValueError( + f"Cannot stack along labels dimension {dim}. " + "Choose a lower dimension index." + ) + + # Stack tensors + new_tensor = torch.stack(tensors, dim=dim) + + # Reassign labels to the new tensor + labels = {key + 1: value for key, value in tensors[0]._labels.items()} new_tensor._labels = labels + return new_tensor def requires_grad_(self, mode=True):