@@ -519,15 +519,17 @@ def _update_single_label(
519519
520520 :raises: ValueError: If the index type is not supported.
521521 """
522-
523- old_dof = old_labels [to_update_dim ]["dof" ]
522+ old_dof = old_labels [dim ]["dof" ]
524523 label_name = old_labels [dim ]["name" ]
525524 # Handle slicing
526525 if isinstance (index , slice ):
527- to_update_labels [dim ] = {"dof" : old_dof [index ], "name" : label_name }
526+ to_update_labels [to_update_dim ] = {
527+ "dof" : old_dof [index ],
528+ "name" : label_name ,
529+ }
528530 # Handle single integer index
529531 elif isinstance (index , int ):
530- to_update_labels [dim ] = {
532+ to_update_labels [to_update_dim ] = {
531533 "dof" : [old_dof [index ]],
532534 "name" : label_name ,
533535 }
@@ -536,7 +538,7 @@ def _update_single_label(
536538 # Handle list of bools
537539 if isinstance (index , torch .Tensor ) and index .dtype == torch .bool :
538540 index = index .nonzero ().squeeze ()
539- to_update_labels [dim ] = {
541+ to_update_labels [to_update_dim ] = {
540542 "dof" : (
541543 [old_dof [i ] for i in index ]
542544 if isinstance (old_dof , list )
@@ -589,10 +591,14 @@ def __getitem__(self, index):
589591
590592 # Update labels based on the index
591593 offset = 0
594+ removed = 0
592595 for dim , idx in enumerate (index ):
593- if dim in self . stored_labels :
596+ if dim in original_labels :
594597 if isinstance (idx , int ):
595- selected_tensor = selected_tensor .unsqueeze (dim )
598+ # Compute the working dimension considering the removed
599+ # dimensions due to int index on a non labled dimension
600+ dim_ = dim - removed
601+ selected_tensor = selected_tensor .unsqueeze (dim_ )
596602 if idx != slice (None ):
597603 self ._update_single_label (
598604 original_labels , updated_labels , idx , dim , offset
@@ -605,6 +611,7 @@ def __getitem__(self, index):
605611 key - 1 if key > dim else key : value
606612 for key , value in updated_labels .items ()
607613 }
614+ removed += 1
608615 continue
609616 offset += 1
610617
0 commit comments