Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 28 additions & 12 deletions pina/_src/core/label_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from copy import copy, deepcopy
import torch
from torch import Tensor


class LabelTensor(torch.Tensor):
Expand All @@ -25,7 +24,6 @@ def __new__(cls, x, labels, *args, **kwargs):
class.
:rtype: LabelTensor
"""

if isinstance(x, LabelTensor):
return x
return super().__new__(cls, x, *args, **kwargs)
Expand All @@ -39,8 +37,7 @@ def tensor(self):
:return: Tensor part of the :class:`~pina.label_tensor.LabelTensor`.
:rtype: torch.Tensor
"""

return self.as_subclass(Tensor)
return self.as_subclass(torch.Tensor)

def __init__(self, x, labels):
"""
Expand Down Expand Up @@ -366,25 +363,44 @@ def cat(tensors, dim=0):
return cat_tensor

@staticmethod
def stack(tensors):
def stack(tensors, dim=0):
"""
Stacks a list of tensors along a new dimension. For more details, see
:meth:`torch.stack`.

:param list[LabelTensor] tensors: A list of tensors to stack.
All tensors must have the same shape.
:param list[LabelTensor] tensors: The list of tensors to stack. All
tensors must have the same shape and labels.
:param int dim: The dimension along which to insert the new stacked
dimension. It follows torch.stack semantics. The new dimension
cannot be inserted in a dimension that already has labels.
:return: A new :class:`~pina.label_tensor.LabelTensor` instance obtained
by stacking the input tensors.
:rtype: LabelTensor
"""
# Ensure all tensors are LabelTensor instances
if not all(isinstance(tensor, LabelTensor) for tensor in tensors):
raise TypeError("All tensors must be LabelTensor instances.")

# Perform stacking in torch
new_tensor = torch.stack(tensors)
# Ensure all tensors have the same labels
reference_labels = tensors[0].labels
for tensor in tensors:
if tensor.labels != reference_labels:
raise ValueError("All tensors must have the same labels.")

# Increase labels keys by 1
labels = tensors[0]._labels
labels = {key + 1: value for key, value in labels.items()}
# Avoid stacking along the labels dimension
if dim in tensors[0]._labels:
raise ValueError(
f"Cannot stack along labels dimension {dim}. "
"Choose a lower dimension index."
)

# Stack tensors
new_tensor = torch.stack(tensors, dim=dim)

# Reassign labels to the new tensor
labels = {key + 1: value for key, value in tensors[0]._labels.items()}
new_tensor._labels = labels

return new_tensor

def requires_grad_(self, mode=True):
Expand Down
Loading