@@ -505,9 +505,7 @@ def vstack(tensors):
505505 return LabelTensor .cat (tensors , dim = 0 )
506506
507507 # This method is used to update labels
508- def _update_single_label (
509- self , old_labels , to_update_labels , index , dim , to_update_dim
510- ):
508+ def _update_single_label (self , index , dim ):
511509 """
512510 Update the labels of the tensor based on the index (or list of indices).
513511
@@ -519,38 +517,30 @@ def _update_single_label(
519517
520518 :raises: ValueError: If the index type is not supported.
521519 """
522- old_dof = old_labels [ dim ][ "dof" ]
523- label_name = old_labels [dim ]["name " ]
520+ print ( self . _labels )
521+ old_dof = self . _labels [dim ]["dof " ]
524522 # Handle slicing
525523 if isinstance (index , slice ):
526- to_update_labels [to_update_dim ] = {
527- "dof" : old_dof [index ],
528- "name" : label_name ,
529- }
524+ new_dof = old_dof [index ]
530525 # Handle single integer index
531526 elif isinstance (index , int ):
532- to_update_labels [to_update_dim ] = {
533- "dof" : [old_dof [index ]],
534- "name" : label_name ,
535- }
527+ new_dof = [old_dof [index ]]
536528 # Handle lists or tensors
537529 elif isinstance (index , (list , torch .Tensor )):
538530 # Handle list of bools
539531 if isinstance (index , torch .Tensor ) and index .dtype == torch .bool :
540532 index = index .nonzero ().squeeze ()
541- to_update_labels [to_update_dim ] = {
542- "dof" : (
543- [old_dof [i ] for i in index ]
544- if isinstance (old_dof , list )
545- else index
546- ),
547- "name" : label_name ,
548- }
533+ new_dof = (
534+ [old_dof [i ] for i in index ]
535+ if isinstance (old_dof , list )
536+ else index
537+ )
549538 else :
550539 raise NotImplementedError (
551540 f"Unsupported index type: { type (index )} . Expected slice, int, "
552541 f"list, or torch.Tensor."
553542 )
543+ return new_dof
554544
555545 def __getitem__ (self , index ):
556546 """ "
@@ -600,9 +590,11 @@ def __getitem__(self, index):
600590 dim_ = dim - removed
601591 selected_tensor = selected_tensor .unsqueeze (dim_ )
602592 if idx != slice (None ):
603- self ._update_single_label (
604- original_labels , updated_labels , idx , dim , offset
605- )
593+ # Update the labels for the selected dimension
594+ updated_labels [offset ] = {
595+ "dof" : self ._update_single_label (idx , dim ),
596+ "name" : original_labels [dim ]["name" ],
597+ }
606598 else :
607599 # Adjust label keys if dimension is reduced (case of integer
608600 # index on a non-labeled dimension)
0 commit comments