@@ -19,7 +19,7 @@ def __new__(cls, x, labels, *args, **kwargs):
1919 :param torch.Tensor x: :class:`torch.tensor` instance to be casted as a
2020 :class:`LabelTensor`.
2121 :param labels: Labels to assign to the tensor.
22- :type labels: str | list( str) | dict
22+ :type labels: str | list[ str] | dict
2323 :return: The instance of the :class:`LabelTensor` class.
2424 :rtype: LabelTensor
2525 """
@@ -115,7 +115,7 @@ def labels(self, labels):
115115 - **str**: The string is assigned to the last dimension.
116116
117117 :param labels: Labels to assign to the class variable _labels.
118- :type labels: str | list( str) | dict
118+ :type labels: str | list[ str] | dict
119119 """
120120
121121 if not hasattr (self , "_labels" ):
@@ -205,11 +205,11 @@ def extract(self, labels_to_extract):
205205 a list of labels is passed, the last dimension is considered.
206206 If a dictionary is passed, the keys are the dimension names and the
207207 values are the labels to extract.
208- :type labels_to_extract: str | list( str) | tuple( str) | dict
208+ :type labels_to_extract: str | list[ str] | tuple[ str] | dict
209209 :return: The extracted tensor with the updated labels.
210210 :rtype: LabelTensor
211211
212- :raises TypeError: Labels are not ``str``, ``list( str) `` or ``dict``
212+ :raises TypeError: Labels are not ``str``, ``list of str`` or ``dict``
213213 properly setted.
214214 :raises ValueError: Label to extract is not in the labels ``list``.
215215 """
@@ -290,7 +290,7 @@ def cat(tensors, dim=0):
290290 Concatenate a list of tensors along a specified dimension. For more
291291 details, see :meth:`torch.cat`.
292292
293- :param list( LabelTensor) tensors: :class:`LabelTensor` instances to
293+ :param list[ LabelTensor] tensors: :class:`LabelTensor` instances to
294294 concatenate
295295 :param int dim: dimensions on which you want to perform the operation
296296 (default is 0)
@@ -344,7 +344,7 @@ def stack(tensors):
344344 Stacks a list of tensors along a new dimension. For more details, see
345345 :meth:`torch.stack`.
346346
347- :param list( LabelTensor) tensors: A list of tensors to stack.
347+ :param list[ LabelTensor] tensors: A list of tensors to stack.
348348 All tensors must have the same shape.
349349 :return: A new :class:`LabelTensor` instance obtained by stacking the
350350 input tensors, with the updated labels.
@@ -466,7 +466,7 @@ def vstack(label_tensors):
466466 """
467467 Stack tensors vertically. For more details, see :meth:`torch.vstack`.
468468
469- :param list( LabelTensor) label_tensors: The :class:`LabelTensor`
469+ :param list of LabelTensor label_tensors: The :class:`LabelTensor`
470470 instances to stack. They need to have equal labels.
471471 :return: A new :class:`LabelTensor` instance obtained by stacking the
472472 input tensors vertically.
@@ -485,7 +485,7 @@ def _update_single_label(
485485 :param dict old_labels: Labels from which retrieve data.
486486 :param dict to_update_labels: Labels to update.
487487 :param index: Index of dof to retain.
488- :type index: int | slice | list | torch.Tensor]
488+ :type index: int | slice | list[int] | tuple[int] | torch.Tensor
489489 :param int dim: The dimension to update.
490490
491491 :raises: ValueError: If the index type is not supported.
@@ -529,7 +529,7 @@ def __getitem__(self, index):
529529 labels based on the index.
530530
531531 :param index: The index used to access the item
532- :type index: int | str | tuple | list | torch.Tensor
532+ :type index: int | str | tuple of int | list ot int | torch.Tensor
533533 :return: A new :class:`LabelTensor` instance obtained __getitem__
534534 operation on :class:`torch.Tensor` part of the instance, with the
535535 updated labels.
@@ -630,7 +630,7 @@ def permute(self, *dims):
630630 accordingly. For more details, see :meth:`torch.Tensor.permute`.
631631
632632 :param dims: The dimensions to permute the tensor to.
633- :type dims: tuple, list
633+ :type dims: tuple[int] | list[int]
634634 :return: A new object with permuted dimensions and reordered labels.
635635 :rtype: LabelTensor
636636 """
@@ -668,8 +668,8 @@ def summation(tensors):
668668 Computes the summation of a list of :class:`LabelTensor` instances.
669669
670670
671- :param list( LabelTensor) tensors: A list of tensors to sum. All tensors
672- must have the same shape and labels.
671+ :param list[ LabelTensor] tensors: A list of tensors to sum. All
672+ tensors must have the same shape and labels.
673673 :return: A new `LabelTensor` containing the element-wise sum of the
674674 input tensors.
675675 :rtype: LabelTensor
@@ -711,7 +711,7 @@ def reshape(self, *shape):
711711 Override the reshape method to update the labels of the tensor.
712712 For more details, see :meth:`torch.Tensor.reshape`.
713713
714- :param tuple shape: The new shape of the tensor.
714+ :param tuple of int shape: The new shape of the tensor.
715715 :return: A new :class:`LabelTensor` instance with the updated shape and
716716 labels.
717717 :rtype: LabelTensor
0 commit comments