Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions pina/label_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,15 +519,17 @@ def _update_single_label(

Comment thread
dario-coscia marked this conversation as resolved.
Outdated
:raises: ValueError: If the index type is not supported.
"""

old_dof = old_labels[to_update_dim]["dof"]
old_dof = old_labels[dim]["dof"]
label_name = old_labels[dim]["name"]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we use old_labels and don't simply pass self.labels?

# Handle slicing
if isinstance(index, slice):
Comment thread
dario-coscia marked this conversation as resolved.
to_update_labels[dim] = {"dof": old_dof[index], "name": label_name}
to_update_labels[to_update_dim] = {
"dof": old_dof[index],
"name": label_name,
}
# Handle single integer index
elif isinstance(index, int):
to_update_labels[dim] = {
to_update_labels[to_update_dim] = {
"dof": [old_dof[index]],
"name": label_name,
}
Expand All @@ -536,7 +538,7 @@ def _update_single_label(
# Handle list of bools
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
index = index.nonzero().squeeze()
to_update_labels[dim] = {
to_update_labels[to_update_dim] = {
"dof": (
[old_dof[i] for i in index]
if isinstance(old_dof, list)
Expand Down Expand Up @@ -589,10 +591,14 @@ def __getitem__(self, index):

# Update labels based on the index
offset = 0
removed = 0
for dim, idx in enumerate(index):
if dim in self.stored_labels:
if dim in original_labels:
if isinstance(idx, int):
selected_tensor = selected_tensor.unsqueeze(dim)
# Compute the working dimension considering the removed
# dimensions due to int index on a non labled dimension
dim_ = dim - removed
selected_tensor = selected_tensor.unsqueeze(dim_)
if idx != slice(None):
self._update_single_label(
original_labels, updated_labels, idx, dim, offset
Expand All @@ -605,6 +611,7 @@ def __getitem__(self, index):
key - 1 if key > dim else key: value
for key, value in updated_labels.items()
}
removed += 1
continue
offset += 1

Expand Down
60 changes: 60 additions & 0 deletions tests/test_label_tensor/test_label_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,63 @@ def test_cat_bool(labels):
assert selected.stored_labels[1]["dof"] == [f"s{i}" for i in range(10)]
if isinstance(labels, dict):
assert selected.stored_labels[0]["dof"] == ["a", "b"]


def test_getitem_int():
data = torch.rand(20, 3)
labels = {1: {"name": 1, "dof": ["x", "y", "z"]}}
lt = LabelTensor(data, labels)
new = lt[0, 0]
assert new.ndim == 1
assert new.shape[0] == 1
assert torch.all(torch.isclose(data[0, 0], new))

data = torch.rand(20, 3, 2)
labels = {
1: {"name": 1, "dof": ["x", "y", "z"]},
2: {"name": 2, "dof": ["a", "b"]},
}
lt = LabelTensor(data, labels)
new = lt[0, 0, 0]
assert new.ndim == 2
assert new.shape[0] == 1
assert new.shape[1] == 1
assert torch.all(torch.isclose(data[0, 0, 0], new))
assert new.stored_labels[0]["dof"] == ["x"]
assert new.stored_labels[1]["dof"] == ["a"]

new = lt[0, 0, :]
assert new.ndim == 2
assert new.shape[0] == 1
assert new.shape[1] == 2
assert torch.all(torch.isclose(data[0, 0, :], new))
assert new.stored_labels[0]["dof"] == ["x"]
assert new.stored_labels[1]["dof"] == ["a", "b"]

new = lt[0, :, 1]
assert new.ndim == 2
assert new.shape[0] == 3
assert new.shape[1] == 1
assert torch.all(torch.isclose(data[0, :, 1], new.squeeze()))
assert new.stored_labels[0]["dof"] == ["x", "y", "z"]
assert new.stored_labels[1]["dof"] == ["b"]

labels.pop(2)
lt = LabelTensor(data, labels)
new = lt[0, 0, 0]
assert new.ndim == 1
assert new.shape[0] == 1
assert new.stored_labels[0]["dof"] == ["x"]

new = lt[:, 0, 0]
assert new.ndim == 2
assert new.shape[0] == 20
assert new.shape[1] == 1
assert new.stored_labels[1]["dof"] == ["x"]

new = lt[:, 0, :]
assert new.ndim == 3
assert new.shape[0] == 20
assert new.shape[1] == 1
assert new.shape[2] == 2
assert new.stored_labels[1]["dof"] == ["x"]