Skip to content

Commit 452195a

Browse files
committed
Fix update labels function
1 parent 3eafc27 commit 452195a

1 file changed

Lines changed: 16 additions & 24 deletions

File tree

pina/label_tensor.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -505,9 +505,7 @@ def vstack(tensors):
505505
return LabelTensor.cat(tensors, dim=0)
506506

507507
# This method is used to update labels
508-
def _update_single_label(
509-
self, old_labels, to_update_labels, index, dim, to_update_dim
510-
):
508+
def _update_single_label(self, index, dim):
511509
"""
512510
Update the labels of the tensor based on the index (or list of indices).
513511
@@ -519,38 +517,30 @@ def _update_single_label(
519517
520518
:raises: ValueError: If the index type is not supported.
521519
"""
522-
old_dof = old_labels[dim]["dof"]
523-
label_name = old_labels[dim]["name"]
520+
print(self._labels)
521+
old_dof = self._labels[dim]["dof"]
524522
# Handle slicing
525523
if isinstance(index, slice):
526-
to_update_labels[to_update_dim] = {
527-
"dof": old_dof[index],
528-
"name": label_name,
529-
}
524+
new_dof = old_dof[index]
530525
# Handle single integer index
531526
elif isinstance(index, int):
532-
to_update_labels[to_update_dim] = {
533-
"dof": [old_dof[index]],
534-
"name": label_name,
535-
}
527+
new_dof = [old_dof[index]]
536528
# Handle lists or tensors
537529
elif isinstance(index, (list, torch.Tensor)):
538530
# Handle list of bools
539531
if isinstance(index, torch.Tensor) and index.dtype == torch.bool:
540532
index = index.nonzero().squeeze()
541-
to_update_labels[to_update_dim] = {
542-
"dof": (
543-
[old_dof[i] for i in index]
544-
if isinstance(old_dof, list)
545-
else index
546-
),
547-
"name": label_name,
548-
}
533+
new_dof = (
534+
[old_dof[i] for i in index]
535+
if isinstance(old_dof, list)
536+
else index
537+
)
549538
else:
550539
raise NotImplementedError(
551540
f"Unsupported index type: {type(index)}. Expected slice, int, "
552541
f"list, or torch.Tensor."
553542
)
543+
return new_dof
554544

555545
def __getitem__(self, index):
556546
""" "
@@ -600,9 +590,11 @@ def __getitem__(self, index):
600590
dim_ = dim - removed
601591
selected_tensor = selected_tensor.unsqueeze(dim_)
602592
if idx != slice(None):
603-
self._update_single_label(
604-
original_labels, updated_labels, idx, dim, offset
605-
)
593+
# Update the labels for the selected dimension
594+
updated_labels[offset] = {
595+
"dof": self._update_single_label(idx, dim),
596+
"name": original_labels[dim]["name"],
597+
}
606598
else:
607599
# Adjust label keys if dimension is reduced (case of integer
608600
# index on a non-labeled dimension)

0 commit comments

Comments
 (0)