Skip to content

Commit 2ce5d53

Browse files
authored
Merge pull request #31 from BerkeleyAutomation/fix/image-resolution
fixing image resolution (#29)
2 parents ee80f2a + 9fe3797 commit 2ce5d53

5 files changed

Lines changed: 489 additions & 134 deletions

File tree

robodm/dataset.py

Lines changed: 85 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,16 @@ class VLADataset:
3838
4. Efficient data management for large datasets
3939
"""
4040

41-
def __init__(self,
42-
path: Text,
43-
mode: Union[str, LoadingMode] = LoadingMode.TRAJECTORY,
44-
split: str = "all",
45-
return_type: str = "numpy",
46-
config: Optional[DatasetConfig] = None,
47-
slice_config: Optional[SliceConfig] = None,
48-
**kwargs):
41+
def __init__(
42+
self,
43+
path: Text,
44+
mode: Union[str, LoadingMode] = LoadingMode.TRAJECTORY,
45+
split: str = "all",
46+
return_type: str = "numpy",
47+
config: Optional[DatasetConfig] = None,
48+
slice_config: Optional[SliceConfig] = None,
49+
**kwargs,
50+
):
4951
"""
5052
Initialize VLA dataset.
5153
@@ -85,37 +87,44 @@ def __init__(self,
8587
shuffle=self.config.shuffle,
8688
num_parallel_reads=self.config.num_parallel_reads,
8789
slice_config=slice_config,
88-
**kwargs)
90+
**kwargs,
91+
)
8992

9093
# Cache for schema and stats
9194
self._schema = None
9295
self._stats = None
9396

9497
@classmethod
95-
def create_trajectory_dataset(cls,
96-
path: Text,
97-
split: str = "all",
98-
return_type: str = "numpy",
99-
config: Optional[DatasetConfig] = None,
100-
**kwargs) -> "VLADataset":
98+
def create_trajectory_dataset(
99+
cls,
100+
path: Text,
101+
split: str = "all",
102+
return_type: str = "numpy",
103+
config: Optional[DatasetConfig] = None,
104+
**kwargs,
105+
) -> "VLADataset":
101106
"""Create a dataset for loading complete trajectories."""
102-
return cls(path=path,
103-
mode=LoadingMode.TRAJECTORY,
104-
return_type=return_type,
105-
config=config,
106-
**kwargs)
107+
return cls(
108+
path=path,
109+
mode=LoadingMode.TRAJECTORY,
110+
return_type=return_type,
111+
config=config,
112+
**kwargs,
113+
)
107114

108115
@classmethod
109-
def create_slice_dataset(cls,
110-
path: Text,
111-
slice_length: int = 100,
112-
return_type: str = "numpy",
113-
config: Optional[DatasetConfig] = None,
114-
min_slice_length: Optional[int] = None,
115-
stride: int = 1,
116-
random_start: bool = True,
117-
overlap_ratio: float = 0.0,
118-
**kwargs) -> "VLADataset":
116+
def create_slice_dataset(
117+
cls,
118+
path: Text,
119+
slice_length: int = 100,
120+
return_type: str = "numpy",
121+
config: Optional[DatasetConfig] = None,
122+
min_slice_length: Optional[int] = None,
123+
stride: int = 1,
124+
random_start: bool = True,
125+
overlap_ratio: float = 0.0,
126+
**kwargs,
127+
) -> "VLADataset":
119128
"""Create a dataset for loading trajectory slices."""
120129
slice_config = SliceConfig(
121130
slice_length=slice_length,
@@ -125,12 +134,14 @@ def create_slice_dataset(cls,
125134
overlap_ratio=overlap_ratio,
126135
)
127136

128-
return cls(path=path,
129-
mode=LoadingMode.SLICE,
130-
return_type=return_type,
131-
config=config,
132-
slice_config=slice_config,
133-
**kwargs)
137+
return cls(
138+
path=path,
139+
mode=LoadingMode.SLICE,
140+
return_type=return_type,
141+
config=config,
142+
slice_config=slice_config,
143+
**kwargs,
144+
)
134145

135146
def get_ray_dataset(self) -> rd.Dataset:
136147
"""Get the underlying Ray dataset."""
@@ -245,7 +256,7 @@ def get_stats(self) -> Dict[str, Any]:
245256
"total_items":
246257
self.count(),
247258
"sample_keys":
248-
list(sample.keys()) if isinstance(sample, dict) else [],
259+
(list(sample.keys()) if isinstance(sample, dict) else []),
249260
}
250261

251262
# Add mode-specific stats
@@ -260,8 +271,9 @@ def get_stats(self) -> Dict[str, Any]:
260271
first_key = next(iter(sample.keys())) if sample else None
261272
if first_key and hasattr(sample[first_key], "__len__"):
262273
self._stats["slice_length"] = len(sample[first_key])
263-
self._stats[
264-
"slice_start"] = 0 # Cannot determine from direct data
274+
self._stats["slice_start"] = (
275+
0 # Cannot determine from direct data
276+
)
265277
self._stats["slice_end"] = len(sample[first_key])
266278
else:
267279
self._stats = {"mode": self.mode.value, "total_items": 0}
@@ -313,13 +325,15 @@ def get_next_trajectory(self):
313325

314326

315327
# Utility functions for common dataset operations
316-
def load_trajectory_dataset(path: Text,
317-
split: str = "all",
318-
return_type: str = "numpy",
319-
batch_size: int = 1,
320-
shuffle: bool = False,
321-
num_parallel_reads: int = 4,
322-
**kwargs) -> VLADataset:
328+
def load_trajectory_dataset(
329+
path: Text,
330+
split: str = "all",
331+
return_type: str = "numpy",
332+
batch_size: int = 1,
333+
shuffle: bool = False,
334+
num_parallel_reads: int = 4,
335+
**kwargs,
336+
) -> VLADataset:
323337
"""Load a dataset for complete trajectories."""
324338
config = DatasetConfig(batch_size=batch_size,
325339
shuffle=shuffle,
@@ -330,31 +344,35 @@ def load_trajectory_dataset(path: Text,
330344
**kwargs)
331345

332346

333-
def load_slice_dataset(path: Text,
334-
slice_length: int = 100,
335-
split: str = "all",
336-
return_type: str = "numpy",
337-
batch_size: int = 1,
338-
shuffle: bool = False,
339-
num_parallel_reads: int = 4,
340-
min_slice_length: Optional[int] = None,
341-
stride: int = 1,
342-
random_start: bool = True,
343-
overlap_ratio: float = 0.0,
344-
**kwargs) -> VLADataset:
347+
def load_slice_dataset(
348+
path: Text,
349+
slice_length: int = 100,
350+
split: str = "all",
351+
return_type: str = "numpy",
352+
batch_size: int = 1,
353+
shuffle: bool = False,
354+
num_parallel_reads: int = 4,
355+
min_slice_length: Optional[int] = None,
356+
stride: int = 1,
357+
random_start: bool = True,
358+
overlap_ratio: float = 0.0,
359+
**kwargs,
360+
) -> VLADataset:
345361
"""Load a dataset for trajectory slices."""
346362
config = DatasetConfig(batch_size=batch_size,
347363
shuffle=shuffle,
348364
num_parallel_reads=num_parallel_reads)
349-
return VLADataset.create_slice_dataset(path=path,
350-
slice_length=slice_length,
351-
return_type=return_type,
352-
config=config,
353-
min_slice_length=min_slice_length,
354-
stride=stride,
355-
random_start=random_start,
356-
overlap_ratio=overlap_ratio,
357-
**kwargs)
365+
return VLADataset.create_slice_dataset(
366+
path=path,
367+
slice_length=slice_length,
368+
return_type=return_type,
369+
config=config,
370+
min_slice_length=min_slice_length,
371+
stride=stride,
372+
random_start=random_start,
373+
overlap_ratio=overlap_ratio,
374+
**kwargs,
375+
)
358376

359377

360378
def split_dataset(

robodm/loader/vla.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@ class SliceConfig:
3434
"""Configuration for slice loading mode."""
3535

3636
slice_length: int = 100 # Number of timesteps per slice
37-
min_slice_length: Optional[
38-
int] = None # Minimum slice length (defaults to slice_length)
37+
min_slice_length: Optional[int] = (
38+
None # Minimum slice length (defaults to slice_length)
39+
)
3940
stride: int = 1 # Stride between consecutive timesteps in slice
4041
random_start: bool = True # Whether to randomly sample start position
4142
overlap_ratio: float = 0.0 # Overlap ratio between consecutive slices (0.0-1.0)

0 commit comments

Comments
 (0)