@@ -589,21 +589,28 @@ def _shuffle_data(self, data: dict[str, Any]) -> dict[str, Any]:
589589 ret [kk ] = data [kk ]
590590 return ret , idx
591591
592- def _get_nframes (self , set_name : DPPath ) -> int :
593- # get nframes
592+ def _get_nframes (self , set_name : DPPath | str ) -> int :
594593 if not isinstance (set_name , DPPath ):
595594 set_name = DPPath (set_name )
596- path = set_name / "coord.npy"
597- # Read only the header to get shape
598- with open (str (path ), "rb" ) as f :
599- version = np .lib .format .read_magic (f )
600- if version [0 ] == 1 :
601- shape , fortran_order , dtype = np .lib .format .read_array_header_1_0 (f )
602- elif version [0 ] in [2 , 3 ]:
603- shape , fortran_order , dtype = np .lib .format .read_array_header_2_0 (f )
604- else :
605- raise ValueError (f"Unsupported .npy file version: { version } " )
606- nframes = shape [0 ] if (len (shape ) if isinstance (shape , tuple ) else 0 ) > 1 else 1
595+ if isinstance (set_name , DPH5Path ):
596+ path = set_name / "coord.npy"
597+ nframes = path .root [path ._name ].shape [0 ]
598+ else :
599+ path = set_name / "coord.npy"
600+ # Read only the header to get shape
601+ with open (str (path ), "rb" ) as f :
602+ version = np .lib .format .read_magic (f )
603+ if version [0 ] == 1 :
604+ shape , _fortran_order , _dtype = np .lib .format .read_array_header_1_0 (
605+ f
606+ )
607+ elif version [0 ] in [2 , 3 ]:
608+ shape , _fortran_order , _dtype = np .lib .format .read_array_header_2_0 (
609+ f
610+ )
611+ else :
612+ raise ValueError (f"Unsupported .npy file version: { version } " )
613+ nframes = shape [0 ] if len (shape ) > 1 else 1
607614 return nframes
608615
609616 def reformat_data_torch (self , data : dict [str , Any ]) -> dict [str , Any ]:
0 commit comments