Skip to content

Commit 0606ff3

Browse files
committed
Fix LabelTensor
1 parent 0a60ed4 commit 0606ff3

2 files changed

Lines changed: 14 additions & 7 deletions

File tree

pina/.DS_Store

10 KB
Binary file not shown.

pina/label_tensor.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -519,15 +519,17 @@ def _update_single_label(
519519
520520
:raises: ValueError: If the index type is not supported.
521521
"""
522-
523-
old_dof = old_labels[to_update_dim]["dof"]
522+
old_dof = old_labels[dim]["dof"]
524523
label_name = old_labels[dim]["name"]
525524
# Handle slicing
526525
if isinstance(index, slice):
527-
to_update_labels[dim] = {"dof": old_dof[index], "name": label_name}
526+
to_update_labels[to_update_dim] = {
527+
"dof": old_dof[index],
528+
"name": label_name,
529+
}
528530
# Handle single integer index
529531
elif isinstance(index, int):
530-
to_update_labels[dim] = {
532+
to_update_labels[to_update_dim] = {
531533
"dof": [old_dof[index]],
532534
"name": label_name,
533535
}
@@ -536,7 +538,7 @@ def _update_single_label(
536538
# Handle list of bools
537539
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
538540
index = index.nonzero().squeeze()
539-
to_update_labels[dim] = {
541+
to_update_labels[to_update_dim] = {
540542
"dof": (
541543
[old_dof[i] for i in index]
542544
if isinstance(old_dof, list)
@@ -589,10 +591,14 @@ def __getitem__(self, index):
589591

590592
# Update labels based on the index
591593
offset = 0
594+
removed = 0
592595
for dim, idx in enumerate(index):
593-
if dim in self.stored_labels:
596+
if dim in original_labels:
594597
if isinstance(idx, int):
595-
selected_tensor = selected_tensor.unsqueeze(dim)
598+
# Compute the working dimension considering the removed
599+
# dimensions due to int index on a non labled dimension
600+
dim_ = dim - removed
601+
selected_tensor = selected_tensor.unsqueeze(dim_)
596602
if idx != slice(None):
597603
self._update_single_label(
598604
original_labels, updated_labels, idx, dim, offset
@@ -605,6 +611,7 @@ def __getitem__(self, index):
605611
key - 1 if key > dim else key: value
606612
for key, value in updated_labels.items()
607613
}
614+
removed += 1
608615
continue
609616
offset += 1
610617

0 commit comments

Comments
 (0)