Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 25 additions & 28 deletions pina/label_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,50 +505,40 @@ def vstack(tensors):
return LabelTensor.cat(tensors, dim=0)

# This method is used to update labels
def _update_single_label(
self, old_labels, to_update_labels, index, dim, to_update_dim
):
def _update_single_label(self, index, dim):
"""
Update the labels of the tensor based on the index (or list of indices).

:param dict old_labels: Labels from which retrieve data.
:param dict to_update_labels: Labels to update.
:param index: Index of dof to retain.
:type index: int | slice | list[int] | tuple[int] | torch.Tensor
:param int dim: The dimension to update.

:param int dim: Dimension of the indexes in the original tensor.
:return: The updated labels for the specified dimension.
:rtype: list[int]
:raises: ValueError: If the index type is not supported.
"""

old_dof = old_labels[to_update_dim]["dof"]
label_name = old_labels[dim]["name"]
old_dof = self._labels[dim]["dof"]
# Handle slicing
if isinstance(index, slice):
Comment thread
dario-coscia marked this conversation as resolved.
to_update_labels[dim] = {"dof": old_dof[index], "name": label_name}
new_dof = old_dof[index]
# Handle single integer index
elif isinstance(index, int):
to_update_labels[dim] = {
"dof": [old_dof[index]],
"name": label_name,
}
new_dof = [old_dof[index]]
# Handle lists or tensors
elif isinstance(index, (list, torch.Tensor)):
# Handle list of bools
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
index = index.nonzero().squeeze()
to_update_labels[dim] = {
"dof": (
[old_dof[i] for i in index]
if isinstance(old_dof, list)
else index
),
"name": label_name,
}
new_dof = (
[old_dof[i] for i in index]
if isinstance(old_dof, list)
else index
)
else:
raise NotImplementedError(
f"Unsupported index type: {type(index)}. Expected slice, int, "
f"list, or torch.Tensor."
)
return new_dof

def __getitem__(self, index):
""" "
Expand Down Expand Up @@ -589,14 +579,20 @@ def __getitem__(self, index):

# Update labels based on the index
offset = 0
removed = 0
for dim, idx in enumerate(index):
if dim in self.stored_labels:
if dim in original_labels:
if isinstance(idx, int):
selected_tensor = selected_tensor.unsqueeze(dim)
# Compute the working dimension considering the removed
# dimensions due to int index on a non labled dimension
dim_ = dim - removed
selected_tensor = selected_tensor.unsqueeze(dim_)
if idx != slice(None):
self._update_single_label(
original_labels, updated_labels, idx, dim, offset
)
# Update the labels for the selected dimension
updated_labels[offset] = {
"dof": self._update_single_label(idx, dim),
"name": original_labels[dim]["name"],
}
else:
# Adjust label keys if dimension is reduced (case of integer
# index on a non-labeled dimension)
Expand All @@ -605,6 +601,7 @@ def __getitem__(self, index):
key - 1 if key > dim else key: value
for key, value in updated_labels.items()
}
removed += 1
continue
offset += 1

Expand Down
60 changes: 60 additions & 0 deletions tests/test_label_tensor/test_label_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,63 @@ def test_cat_bool(labels):
assert selected.stored_labels[1]["dof"] == [f"s{i}" for i in range(10)]
if isinstance(labels, dict):
assert selected.stored_labels[0]["dof"] == ["a", "b"]


def test_getitem_int():
data = torch.rand(20, 3)
labels = {1: {"name": 1, "dof": ["x", "y", "z"]}}
lt = LabelTensor(data, labels)
new = lt[0, 0]
assert new.ndim == 1
assert new.shape[0] == 1
assert torch.all(torch.isclose(data[0, 0], new))

data = torch.rand(20, 3, 2)
labels = {
1: {"name": 1, "dof": ["x", "y", "z"]},
2: {"name": 2, "dof": ["a", "b"]},
}
lt = LabelTensor(data, labels)
new = lt[0, 0, 0]
assert new.ndim == 2
assert new.shape[0] == 1
assert new.shape[1] == 1
assert torch.all(torch.isclose(data[0, 0, 0], new))
assert new.stored_labels[0]["dof"] == ["x"]
assert new.stored_labels[1]["dof"] == ["a"]

new = lt[0, 0, :]
assert new.ndim == 2
assert new.shape[0] == 1
assert new.shape[1] == 2
assert torch.all(torch.isclose(data[0, 0, :], new))
assert new.stored_labels[0]["dof"] == ["x"]
assert new.stored_labels[1]["dof"] == ["a", "b"]

new = lt[0, :, 1]
assert new.ndim == 2
assert new.shape[0] == 3
assert new.shape[1] == 1
assert torch.all(torch.isclose(data[0, :, 1], new.squeeze()))
assert new.stored_labels[0]["dof"] == ["x", "y", "z"]
assert new.stored_labels[1]["dof"] == ["b"]

labels.pop(2)
lt = LabelTensor(data, labels)
new = lt[0, 0, 0]
assert new.ndim == 1
assert new.shape[0] == 1
assert new.stored_labels[0]["dof"] == ["x"]

new = lt[:, 0, 0]
assert new.ndim == 2
assert new.shape[0] == 20
assert new.shape[1] == 1
assert new.stored_labels[1]["dof"] == ["x"]

new = lt[:, 0, :]
assert new.ndim == 3
assert new.shape[0] == 20
assert new.shape[1] == 1
assert new.shape[2] == 2
assert new.stored_labels[1]["dof"] == ["x"]
Loading