Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 85 additions & 67 deletions robodm/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ class VLADataset:
4. Efficient data management for large datasets
"""

def __init__(self,
path: Text,
mode: Union[str, LoadingMode] = LoadingMode.TRAJECTORY,
split: str = "all",
return_type: str = "numpy",
config: Optional[DatasetConfig] = None,
slice_config: Optional[SliceConfig] = None,
**kwargs):
def __init__(
self,
path: Text,
mode: Union[str, LoadingMode] = LoadingMode.TRAJECTORY,
split: str = "all",
return_type: str = "numpy",
config: Optional[DatasetConfig] = None,
slice_config: Optional[SliceConfig] = None,
**kwargs,
):
"""
Initialize VLA dataset.

Expand Down Expand Up @@ -85,37 +87,44 @@ def __init__(self,
shuffle=self.config.shuffle,
num_parallel_reads=self.config.num_parallel_reads,
slice_config=slice_config,
**kwargs)
**kwargs,
)

# Cache for schema and stats
self._schema = None
self._stats = None

@classmethod
def create_trajectory_dataset(cls,
path: Text,
split: str = "all",
return_type: str = "numpy",
config: Optional[DatasetConfig] = None,
**kwargs) -> "VLADataset":
def create_trajectory_dataset(
cls,
path: Text,
split: str = "all",
return_type: str = "numpy",
config: Optional[DatasetConfig] = None,
**kwargs,
) -> "VLADataset":
"""Create a dataset for loading complete trajectories."""
return cls(path=path,
mode=LoadingMode.TRAJECTORY,
return_type=return_type,
config=config,
**kwargs)
return cls(
path=path,
mode=LoadingMode.TRAJECTORY,
return_type=return_type,
config=config,
**kwargs,
)

@classmethod
def create_slice_dataset(cls,
path: Text,
slice_length: int = 100,
return_type: str = "numpy",
config: Optional[DatasetConfig] = None,
min_slice_length: Optional[int] = None,
stride: int = 1,
random_start: bool = True,
overlap_ratio: float = 0.0,
**kwargs) -> "VLADataset":
def create_slice_dataset(
cls,
path: Text,
slice_length: int = 100,
return_type: str = "numpy",
config: Optional[DatasetConfig] = None,
min_slice_length: Optional[int] = None,
stride: int = 1,
random_start: bool = True,
overlap_ratio: float = 0.0,
**kwargs,
) -> "VLADataset":
"""Create a dataset for loading trajectory slices."""
slice_config = SliceConfig(
slice_length=slice_length,
Expand All @@ -125,12 +134,14 @@ def create_slice_dataset(cls,
overlap_ratio=overlap_ratio,
)

return cls(path=path,
mode=LoadingMode.SLICE,
return_type=return_type,
config=config,
slice_config=slice_config,
**kwargs)
return cls(
path=path,
mode=LoadingMode.SLICE,
return_type=return_type,
config=config,
slice_config=slice_config,
**kwargs,
)

def get_ray_dataset(self) -> rd.Dataset:
"""Get the underlying Ray dataset."""
Expand Down Expand Up @@ -245,7 +256,7 @@ def get_stats(self) -> Dict[str, Any]:
"total_items":
self.count(),
"sample_keys":
list(sample.keys()) if isinstance(sample, dict) else [],
(list(sample.keys()) if isinstance(sample, dict) else []),
}

# Add mode-specific stats
Expand All @@ -260,8 +271,9 @@ def get_stats(self) -> Dict[str, Any]:
first_key = next(iter(sample.keys())) if sample else None
if first_key and hasattr(sample[first_key], "__len__"):
self._stats["slice_length"] = len(sample[first_key])
self._stats[
"slice_start"] = 0 # Cannot determine from direct data
self._stats["slice_start"] = (
0 # Cannot determine from direct data
)
self._stats["slice_end"] = len(sample[first_key])
else:
self._stats = {"mode": self.mode.value, "total_items": 0}
Expand Down Expand Up @@ -313,13 +325,15 @@ def get_next_trajectory(self):


# Utility functions for common dataset operations
def load_trajectory_dataset(path: Text,
split: str = "all",
return_type: str = "numpy",
batch_size: int = 1,
shuffle: bool = False,
num_parallel_reads: int = 4,
**kwargs) -> VLADataset:
def load_trajectory_dataset(
path: Text,
split: str = "all",
return_type: str = "numpy",
batch_size: int = 1,
shuffle: bool = False,
num_parallel_reads: int = 4,
**kwargs,
) -> VLADataset:
"""Load a dataset for complete trajectories."""
config = DatasetConfig(batch_size=batch_size,
shuffle=shuffle,
Expand All @@ -330,31 +344,35 @@ def load_trajectory_dataset(path: Text,
**kwargs)


def load_slice_dataset(path: Text,
slice_length: int = 100,
split: str = "all",
return_type: str = "numpy",
batch_size: int = 1,
shuffle: bool = False,
num_parallel_reads: int = 4,
min_slice_length: Optional[int] = None,
stride: int = 1,
random_start: bool = True,
overlap_ratio: float = 0.0,
**kwargs) -> VLADataset:
def load_slice_dataset(
path: Text,
slice_length: int = 100,
split: str = "all",
return_type: str = "numpy",
batch_size: int = 1,
shuffle: bool = False,
num_parallel_reads: int = 4,
min_slice_length: Optional[int] = None,
stride: int = 1,
random_start: bool = True,
overlap_ratio: float = 0.0,
**kwargs,
) -> VLADataset:
"""Load a dataset for trajectory slices."""
config = DatasetConfig(batch_size=batch_size,
shuffle=shuffle,
num_parallel_reads=num_parallel_reads)
return VLADataset.create_slice_dataset(path=path,
slice_length=slice_length,
return_type=return_type,
config=config,
min_slice_length=min_slice_length,
stride=stride,
random_start=random_start,
overlap_ratio=overlap_ratio,
**kwargs)
return VLADataset.create_slice_dataset(
path=path,
slice_length=slice_length,
return_type=return_type,
config=config,
min_slice_length=min_slice_length,
stride=stride,
random_start=random_start,
overlap_ratio=overlap_ratio,
**kwargs,
)


def split_dataset(
Expand Down
5 changes: 3 additions & 2 deletions robodm/loader/vla.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ class SliceConfig:
"""Configuration for slice loading mode."""

slice_length: int = 100 # Number of timesteps per slice
min_slice_length: Optional[
int] = None # Minimum slice length (defaults to slice_length)
min_slice_length: Optional[int] = (
None # Minimum slice length (defaults to slice_length)
)
stride: int = 1 # Stride between consecutive timesteps in slice
random_start: bool = True # Whether to randomly sample start position
overlap_ratio: float = 0.0 # Overlap ratio between consecutive slices (0.0-1.0)
Expand Down
Loading
Loading