@@ -38,14 +38,16 @@ class VLADataset:
3838 4. Efficient data management for large datasets
3939 """
4040
41- def __init__ (self ,
42- path : Text ,
43- mode : Union [str , LoadingMode ] = LoadingMode .TRAJECTORY ,
44- split : str = "all" ,
45- return_type : str = "numpy" ,
46- config : Optional [DatasetConfig ] = None ,
47- slice_config : Optional [SliceConfig ] = None ,
48- ** kwargs ):
41+ def __init__ (
42+ self ,
43+ path : Text ,
44+ mode : Union [str , LoadingMode ] = LoadingMode .TRAJECTORY ,
45+ split : str = "all" ,
46+ return_type : str = "numpy" ,
47+ config : Optional [DatasetConfig ] = None ,
48+ slice_config : Optional [SliceConfig ] = None ,
49+ ** kwargs ,
50+ ):
4951 """
5052 Initialize VLA dataset.
5153
@@ -85,37 +87,44 @@ def __init__(self,
8587 shuffle = self .config .shuffle ,
8688 num_parallel_reads = self .config .num_parallel_reads ,
8789 slice_config = slice_config ,
88- ** kwargs )
90+ ** kwargs ,
91+ )
8992
9093 # Cache for schema and stats
9194 self ._schema = None
9295 self ._stats = None
9396
9497 @classmethod
95- def create_trajectory_dataset (cls ,
96- path : Text ,
97- split : str = "all" ,
98- return_type : str = "numpy" ,
99- config : Optional [DatasetConfig ] = None ,
100- ** kwargs ) -> "VLADataset" :
98+ def create_trajectory_dataset (
99+ cls ,
100+ path : Text ,
101+ split : str = "all" ,
102+ return_type : str = "numpy" ,
103+ config : Optional [DatasetConfig ] = None ,
104+ ** kwargs ,
105+ ) -> "VLADataset" :
101106 """Create a dataset for loading complete trajectories."""
102- return cls (path = path ,
103- mode = LoadingMode .TRAJECTORY ,
104- return_type = return_type ,
105- config = config ,
106- ** kwargs )
107+ return cls (
108+ path = path ,
109+ mode = LoadingMode .TRAJECTORY ,
110+ return_type = return_type ,
111+ config = config ,
112+ ** kwargs ,
113+ )
107114
108115 @classmethod
109- def create_slice_dataset (cls ,
110- path : Text ,
111- slice_length : int = 100 ,
112- return_type : str = "numpy" ,
113- config : Optional [DatasetConfig ] = None ,
114- min_slice_length : Optional [int ] = None ,
115- stride : int = 1 ,
116- random_start : bool = True ,
117- overlap_ratio : float = 0.0 ,
118- ** kwargs ) -> "VLADataset" :
116+ def create_slice_dataset (
117+ cls ,
118+ path : Text ,
119+ slice_length : int = 100 ,
120+ return_type : str = "numpy" ,
121+ config : Optional [DatasetConfig ] = None ,
122+ min_slice_length : Optional [int ] = None ,
123+ stride : int = 1 ,
124+ random_start : bool = True ,
125+ overlap_ratio : float = 0.0 ,
126+ ** kwargs ,
127+ ) -> "VLADataset" :
119128 """Create a dataset for loading trajectory slices."""
120129 slice_config = SliceConfig (
121130 slice_length = slice_length ,
@@ -125,12 +134,14 @@ def create_slice_dataset(cls,
125134 overlap_ratio = overlap_ratio ,
126135 )
127136
128- return cls (path = path ,
129- mode = LoadingMode .SLICE ,
130- return_type = return_type ,
131- config = config ,
132- slice_config = slice_config ,
133- ** kwargs )
137+ return cls (
138+ path = path ,
139+ mode = LoadingMode .SLICE ,
140+ return_type = return_type ,
141+ config = config ,
142+ slice_config = slice_config ,
143+ ** kwargs ,
144+ )
134145
135146 def get_ray_dataset (self ) -> rd .Dataset :
136147 """Get the underlying Ray dataset."""
@@ -245,7 +256,7 @@ def get_stats(self) -> Dict[str, Any]:
245256 "total_items" :
246257 self .count (),
247258 "sample_keys" :
248- list (sample .keys ()) if isinstance (sample , dict ) else [],
259+ ( list (sample .keys ()) if isinstance (sample , dict ) else []) ,
249260 }
250261
251262 # Add mode-specific stats
@@ -260,8 +271,9 @@ def get_stats(self) -> Dict[str, Any]:
260271 first_key = next (iter (sample .keys ())) if sample else None
261272 if first_key and hasattr (sample [first_key ], "__len__" ):
262273 self ._stats ["slice_length" ] = len (sample [first_key ])
263- self ._stats [
264- "slice_start" ] = 0 # Cannot determine from direct data
274+ self ._stats ["slice_start" ] = (
275+ 0 # Cannot determine from direct data
276+ )
265277 self ._stats ["slice_end" ] = len (sample [first_key ])
266278 else :
267279 self ._stats = {"mode" : self .mode .value , "total_items" : 0 }
@@ -313,13 +325,15 @@ def get_next_trajectory(self):
313325
314326
315327# Utility functions for common dataset operations
316- def load_trajectory_dataset (path : Text ,
317- split : str = "all" ,
318- return_type : str = "numpy" ,
319- batch_size : int = 1 ,
320- shuffle : bool = False ,
321- num_parallel_reads : int = 4 ,
322- ** kwargs ) -> VLADataset :
328+ def load_trajectory_dataset (
329+ path : Text ,
330+ split : str = "all" ,
331+ return_type : str = "numpy" ,
332+ batch_size : int = 1 ,
333+ shuffle : bool = False ,
334+ num_parallel_reads : int = 4 ,
335+ ** kwargs ,
336+ ) -> VLADataset :
323337 """Load a dataset for complete trajectories."""
324338 config = DatasetConfig (batch_size = batch_size ,
325339 shuffle = shuffle ,
@@ -330,31 +344,35 @@ def load_trajectory_dataset(path: Text,
330344 ** kwargs )
331345
332346
333- def load_slice_dataset (path : Text ,
334- slice_length : int = 100 ,
335- split : str = "all" ,
336- return_type : str = "numpy" ,
337- batch_size : int = 1 ,
338- shuffle : bool = False ,
339- num_parallel_reads : int = 4 ,
340- min_slice_length : Optional [int ] = None ,
341- stride : int = 1 ,
342- random_start : bool = True ,
343- overlap_ratio : float = 0.0 ,
344- ** kwargs ) -> VLADataset :
347+ def load_slice_dataset (
348+ path : Text ,
349+ slice_length : int = 100 ,
350+ split : str = "all" ,
351+ return_type : str = "numpy" ,
352+ batch_size : int = 1 ,
353+ shuffle : bool = False ,
354+ num_parallel_reads : int = 4 ,
355+ min_slice_length : Optional [int ] = None ,
356+ stride : int = 1 ,
357+ random_start : bool = True ,
358+ overlap_ratio : float = 0.0 ,
359+ ** kwargs ,
360+ ) -> VLADataset :
345361 """Load a dataset for trajectory slices."""
346362 config = DatasetConfig (batch_size = batch_size ,
347363 shuffle = shuffle ,
348364 num_parallel_reads = num_parallel_reads )
349- return VLADataset .create_slice_dataset (path = path ,
350- slice_length = slice_length ,
351- return_type = return_type ,
352- config = config ,
353- min_slice_length = min_slice_length ,
354- stride = stride ,
355- random_start = random_start ,
356- overlap_ratio = overlap_ratio ,
357- ** kwargs )
365+ return VLADataset .create_slice_dataset (
366+ path = path ,
367+ slice_length = slice_length ,
368+ return_type = return_type ,
369+ config = config ,
370+ min_slice_length = min_slice_length ,
371+ stride = stride ,
372+ random_start = random_start ,
373+ overlap_ratio = overlap_ratio ,
374+ ** kwargs ,
375+ )
358376
359377
360378def split_dataset (
0 commit comments