Skip to content

Commit 3eafc27

Browse files
committed
Fix LabelTensor
1 parent 0a60ed4 commit 3eafc27

2 files changed

Lines changed: 74 additions & 7 deletions

File tree

pina/label_tensor.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -519,15 +519,17 @@ def _update_single_label(
519519
520520
:raises: ValueError: If the index type is not supported.
521521
"""
522-
523-
old_dof = old_labels[to_update_dim]["dof"]
522+
old_dof = old_labels[dim]["dof"]
524523
label_name = old_labels[dim]["name"]
525524
# Handle slicing
526525
if isinstance(index, slice):
527-
to_update_labels[dim] = {"dof": old_dof[index], "name": label_name}
526+
to_update_labels[to_update_dim] = {
527+
"dof": old_dof[index],
528+
"name": label_name,
529+
}
528530
# Handle single integer index
529531
elif isinstance(index, int):
530-
to_update_labels[dim] = {
532+
to_update_labels[to_update_dim] = {
531533
"dof": [old_dof[index]],
532534
"name": label_name,
533535
}
@@ -536,7 +538,7 @@ def _update_single_label(
536538
# Handle list of bools
537539
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
538540
index = index.nonzero().squeeze()
539-
to_update_labels[dim] = {
541+
to_update_labels[to_update_dim] = {
540542
"dof": (
541543
[old_dof[i] for i in index]
542544
if isinstance(old_dof, list)
@@ -589,10 +591,14 @@ def __getitem__(self, index):
589591

590592
# Update labels based on the index
591593
offset = 0
594+
removed = 0
592595
for dim, idx in enumerate(index):
593-
if dim in self.stored_labels:
596+
if dim in original_labels:
594597
if isinstance(idx, int):
595-
selected_tensor = selected_tensor.unsqueeze(dim)
598+
# Compute the working dimension considering the removed
599+
# dimensions due to int index on a non labled dimension
600+
dim_ = dim - removed
601+
selected_tensor = selected_tensor.unsqueeze(dim_)
596602
if idx != slice(None):
597603
self._update_single_label(
598604
original_labels, updated_labels, idx, dim, offset
@@ -605,6 +611,7 @@ def __getitem__(self, index):
605611
key - 1 if key > dim else key: value
606612
for key, value in updated_labels.items()
607613
}
614+
removed += 1
608615
continue
609616
offset += 1
610617

tests/test_label_tensor/test_label_tensor.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,63 @@ def test_cat_bool(labels):
278278
assert selected.stored_labels[1]["dof"] == [f"s{i}" for i in range(10)]
279279
if isinstance(labels, dict):
280280
assert selected.stored_labels[0]["dof"] == ["a", "b"]
281+
282+
283+
def test_getitem_int():
284+
data = torch.rand(20, 3)
285+
labels = {1: {"name": 1, "dof": ["x", "y", "z"]}}
286+
lt = LabelTensor(data, labels)
287+
new = lt[0, 0]
288+
assert new.ndim == 1
289+
assert new.shape[0] == 1
290+
assert torch.all(torch.isclose(data[0, 0], new))
291+
292+
data = torch.rand(20, 3, 2)
293+
labels = {
294+
1: {"name": 1, "dof": ["x", "y", "z"]},
295+
2: {"name": 2, "dof": ["a", "b"]},
296+
}
297+
lt = LabelTensor(data, labels)
298+
new = lt[0, 0, 0]
299+
assert new.ndim == 2
300+
assert new.shape[0] == 1
301+
assert new.shape[1] == 1
302+
assert torch.all(torch.isclose(data[0, 0, 0], new))
303+
assert new.stored_labels[0]["dof"] == ["x"]
304+
assert new.stored_labels[1]["dof"] == ["a"]
305+
306+
new = lt[0, 0, :]
307+
assert new.ndim == 2
308+
assert new.shape[0] == 1
309+
assert new.shape[1] == 2
310+
assert torch.all(torch.isclose(data[0, 0, :], new))
311+
assert new.stored_labels[0]["dof"] == ["x"]
312+
assert new.stored_labels[1]["dof"] == ["a", "b"]
313+
314+
new = lt[0, :, 1]
315+
assert new.ndim == 2
316+
assert new.shape[0] == 3
317+
assert new.shape[1] == 1
318+
assert torch.all(torch.isclose(data[0, :, 1], new.squeeze()))
319+
assert new.stored_labels[0]["dof"] == ["x", "y", "z"]
320+
assert new.stored_labels[1]["dof"] == ["b"]
321+
322+
labels.pop(2)
323+
lt = LabelTensor(data, labels)
324+
new = lt[0, 0, 0]
325+
assert new.ndim == 1
326+
assert new.shape[0] == 1
327+
assert new.stored_labels[0]["dof"] == ["x"]
328+
329+
new = lt[:, 0, 0]
330+
assert new.ndim == 2
331+
assert new.shape[0] == 20
332+
assert new.shape[1] == 1
333+
assert new.stored_labels[1]["dof"] == ["x"]
334+
335+
new = lt[:, 0, :]
336+
assert new.ndim == 3
337+
assert new.shape[0] == 20
338+
assert new.shape[1] == 1
339+
assert new.shape[2] == 2
340+
assert new.stored_labels[1]["dof"] == ["x"]

0 commit comments

Comments
 (0)