From 5c231b3fab98065ccf6a11ed308ed69d87f4a77b Mon Sep 17 00:00:00 2001 From: Eric Kaiyuan Chen Date: Thu, 5 Jun 2025 22:49:13 -0700 Subject: [PATCH 1/2] fix --- robodm/trajectory.py | 197 +++++++++++++++++++++++--------- tests/test_shape_codec_logic.py | 182 +++++++++++++++++++++++++++++ 2 files changed, 328 insertions(+), 51 deletions(-) create mode 100644 tests/test_shape_codec_logic.py diff --git a/robodm/trajectory.py b/robodm/trajectory.py index 0d4e3b5..8365f3b 100644 --- a/robodm/trajectory.py +++ b/robodm/trajectory.py @@ -8,7 +8,7 @@ from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta, timezone from fractions import Fraction -from typing import Any, Dict, List, Optional, Text, Union, cast +from typing import Any, Dict, List, Optional, Text, Tuple, Union, cast import av import h5py @@ -316,6 +316,62 @@ def __repr__(self): class CodecConfig: """Configuration class for video codec settings.""" + @staticmethod + def get_supported_pixel_formats(codec_name: str) -> List[str]: + """Get list of supported pixel formats for a codec.""" + try: + import av + codec = av.codec.Codec(codec_name, "w") + if codec.video_formats: + return [vf.name for vf in codec.video_formats] + return [] + except Exception: + return [] + + @staticmethod + def is_codec_config_supported(width: int, height: int, pix_fmt: str = "yuv420p", codec_name: str = "libx264") -> bool: + """Check if a specific width/height/pixel format combination is supported by codec.""" + try: + import av + from fractions import Fraction + + cc = av.codec.CodecContext.create(codec_name, "w") + cc.width = width + cc.height = height + cc.pix_fmt = pix_fmt + cc.time_base = Fraction(1, 30) + cc.open(strict=True) + cc.close() + return True + except Exception: + return False + + @staticmethod + def is_valid_image_shape(shape: Tuple[int, ...], codec_name: str = "libx264") -> bool: + """Check if a shape can be treated as an RGB image for the given codec.""" + # Only accept RGB shapes (H, W, 3) + if len(shape) != 3 or shape[2] != 3: + return False + + height, width = shape[0], shape[1] + + # Check minimum reasonable image size + if height < 1 or width < 1: + return False + + # Check codec-specific constraints + if codec_name in ["libx264", "libx265"]: + # H.264/H.265 require even dimensions + if height % 2 != 0 or width % 2 != 0: + return False + elif codec_name in ["libaom-av1"]: + # AV1 also typically requires even dimensions for yuv420p + if height % 2 != 0 or width % 2 != 0: + return False + + # Test if the codec actually supports this resolution + return CodecConfig.is_codec_config_supported(width, height, "yuv420p", codec_name) + # Default codec configurations CODEC_CONFIGS = { "rawvideo": { @@ -371,17 +427,38 @@ def __init__(self, def get_codec_for_feature(self, feature_type: FeatureType) -> str: """Determine the appropriate codec for a given feature type.""" - # Auto-selection logic based on feature characteristics data_shape = feature_type.shape - if (data_shape is not None and len(data_shape) >= 2 - and data_shape[0] >= 100 and data_shape[1] >= 100): - # Large images - use efficient video codec + + # Only use video codecs for RGB images (H, W, 3) + if (data_shape is not None and len(data_shape) == 3 and data_shape[2] == 3): + height, width = data_shape[0], data_shape[1] + + # If user specified a codec other than auto, try to use it for RGB images if self.codec != "auto": - return self.codec - return "libaom-av1" # Default to AV1 for large images + if self.is_valid_image_shape(data_shape, self.codec): + logger.debug(f"Using user-specified codec {self.codec} for RGB shape {data_shape}") + return self.codec + else: + logger.warning(f"User-specified codec {self.codec} doesn't support shape {data_shape}, falling back to rawvideo") + return "rawvideo" + + # Auto-selection for RGB images only + codec_preferences = ["libaom-av1", "ffv1", "libx264", "libx265"] + + for codec in codec_preferences: + if self.is_valid_image_shape(data_shape, codec): + logger.debug(f"Selected codec {codec} for RGB shape {data_shape}") + return codec + + # If no video codec works for this RGB image, fall back to rawvideo + logger.warning(f"No video codec supports RGB shape {data_shape}, falling back to rawvideo") + else: - # Small data or non-image data - use rawvideo - return "rawvideo" + # Non-RGB data (grayscale, depth, vectors, etc.) always use rawvideo + if data_shape is not None: + logger.debug(f"Using rawvideo for non-RGB shape {data_shape}") + + return "rawvideo" def get_pixel_format(self, codec: str, feature_type: FeatureType) -> Optional[str]: @@ -394,19 +471,16 @@ def get_pixel_format(self, codec: str, if base_format is None: # rawvideo case return None - # Adjust pixel format based on feature type + # Only use RGB formats for actual RGB data (H, W, 3) shape = feature_type.shape if shape is not None and len(shape) == 3 and shape[2] == 3: - # RGB image + # RGB data - use appropriate RGB format return ("yuv420p" if codec in [ "libx264", "libx265", "libaom-av1", "ffv1" ] else "rgb24") - elif shape is not None and (len(shape) == 2 or - (len(shape) == 3 and shape[2] == 1)): - # Grayscale image - return "gray" else: - return base_format + # Non-RGB data should not get video pixel formats + return None def get_codec_options(self, codec: str) -> Dict[str, Any]: """Get codec options, merging defaults with custom options.""" @@ -954,15 +1028,18 @@ def want(idx: int) -> bool: else: for frame in packet.decode(): ft = self.feature_name_to_feature_type[fname] - if ft.dtype == "float32": - arr = frame.to_ndarray( - format="gray") # depth / float32 - if ft.shape: - arr = arr.reshape(ft.shape) + # Only decode as RGB24 for RGB data, otherwise this shouldn't happen + # since non-RGB data should use rawvideo + if ft.shape and len(ft.shape) == 3 and ft.shape[2] == 3: + # RGB data - decode as RGB24 + arr = frame.to_ndarray(format="rgb24") else: + # This shouldn't happen with our new logic, but handle gracefully + logger.warning(f"Non-RGB data {fname} with shape {ft.shape} using video codec - this may cause issues") arr = frame.to_ndarray(format="rgb24") - if ft.shape: - arr = arr.reshape(ft.shape) + + if ft.shape: + arr = arr.reshape(ft.shape) cache[fname].append(arr) decoded_packets += 1 logger.debug( @@ -999,10 +1076,15 @@ def want(idx: int) -> bool: continue ft = self.feature_name_to_feature_type[fname] - if ft.dtype == "float32": - arr = frame.to_ndarray(format="gray") + # Only decode as RGB24 for RGB data + if ft.shape and len(ft.shape) == 3 and ft.shape[2] == 3: + # RGB data - decode as RGB24 + arr = frame.to_ndarray(format="rgb24") else: + # This shouldn't happen with our new logic, but handle gracefully + logger.warning(f"Non-RGB data {fname} with shape {ft.shape} using video codec - this may cause issues") arr = frame.to_ndarray(format="rgb24") + if ft.shape: arr = arr.reshape(ft.shape) cache[fname].append(arr) @@ -1094,6 +1176,8 @@ def add( if type(data) == dict: raise ValueError("Use add_by_dict for dictionary") + + feature_type = FeatureType.from_data(data) # encoding = self._get_encoding_of_feature(data, None) self.feature_name_to_feature_type[feature] = feature_type @@ -1334,16 +1418,19 @@ def _load_from_container(self): else: frames = packet.decode() for frame in frames: - if feature_type.dtype == "float32": - shape = feature_type.shape + # Only decode as RGB24 for RGB data + shape = feature_type.shape + if shape and len(shape) == 3 and shape[2] == 3: + # RGB data - decode as RGB24 if shape is not None: data = frame.to_ndarray( # type: ignore[attr-defined] - format="gray").reshape(shape) + format="rgb24").reshape(shape) else: data = frame.to_ndarray( - format="gray") # type: ignore[attr-defined] + format="rgb24") # type: ignore[attr-defined] else: - shape = feature_type.shape + # This shouldn't happen with our new logic, but handle gracefully + logger.warning(f"Non-RGB data {feature_name} with shape {shape} using video codec") if shape is not None: data = frame.to_ndarray( # type: ignore[attr-defined] format="rgb24").reshape(shape) @@ -1538,10 +1625,8 @@ def _encode_frame(self, data: Any, stream: Any, if (encoding in ["ffv1", "libaom-av1", "libx264", "libx265"] and shape is not None and len(shape) >= 2): logger.debug("Using video encoding path for image-like data") - if feature_type.dtype == "float32": - frame = self._create_frame_depth(data, stream) - else: - frame = self._create_frame(data, stream) + # Always use RGB frame creation, no special handling for float32 + frame = self._create_frame(data, stream) frame.pts = timestamp frame.dts = timestamp frame.time_base = stream.time_base @@ -1685,28 +1770,38 @@ def _add_stream_to_container(self, container, feature_name, encoding, return stream def _create_frame(self, image_array, stream): - image_array = np.array(image_array, dtype=np.uint8) + image_array = np.array(image_array) encoding = stream.codec_context.codec.name - # Determine the correct format based on array shape and codec - if len(image_array.shape) == 3 and image_array.shape[2] == 3: - # RGB image - if encoding in ["libaom-av1", "ffv1", "libx264", "libx265"]: - # For video codecs that prefer YUV, convert RGB to YUV420p - frame = av.VideoFrame.from_ndarray(image_array, format="rgb24") - frame = frame.reformat(format="yuv420p") + # Convert to uint8 if needed + if image_array.dtype == np.float32: + # Assume float32 values are in [0, 1] range, scale to [0, 255] + image_array = np.clip(image_array * 255, 0, 255).astype(np.uint8) + elif image_array.dtype != np.uint8: + # Convert other dtypes to uint8 + if np.issubdtype(image_array.dtype, np.integer): + # For integer types, clamp to 0-255 range + image_array = np.clip(image_array, 0, 255).astype(np.uint8) else: - frame = av.VideoFrame.from_ndarray(image_array, format="rgb24") - elif len(image_array.shape) == 3 and image_array.shape[2] == 1: - # Single channel image, squeeze the last dimension - frame = av.VideoFrame.from_ndarray(image_array.squeeze(axis=2), - format="gray") - elif len(image_array.shape) == 2: - # Grayscale image - frame = av.VideoFrame.from_ndarray(image_array, format="gray") + # For other types, normalize and convert + image_array = np.clip(image_array * 255, 0, 255).astype(np.uint8) + + # Only handle RGB images (HxWx3) - no grayscale conversion + if len(image_array.shape) == 3 and image_array.shape[2] == 3: + # RGB image - proceed with video encoding + pass else: raise ValueError( - f"Unsupported image array shape: {image_array.shape}") + f"Video codecs only support RGB images with shape (H, W, 3). " + f"Got shape {image_array.shape}. Use rawvideo encoding for other formats.") + + # Create RGB frame + if encoding in ["libaom-av1", "ffv1", "libx264", "libx265"]: + # For video codecs that prefer YUV, convert RGB to YUV420p + frame = av.VideoFrame.from_ndarray(image_array, format="rgb24") + frame = frame.reformat(format="yuv420p") + else: + frame = av.VideoFrame.from_ndarray(image_array, format="rgb24") frame.time_base = stream.time_base return frame diff --git a/tests/test_shape_codec_logic.py b/tests/test_shape_codec_logic.py new file mode 100644 index 0000000..fa10c18 --- /dev/null +++ b/tests/test_shape_codec_logic.py @@ -0,0 +1,182 @@ +"""Test cases for shape-based codec selection and dimensionality checking.""" + +import os +import tempfile + +import numpy as np +import pytest + +from robodm import FeatureType, Trajectory +from robodm.trajectory import CodecConfig + + +class TestShapeBasedCodecSelection: + """Test codec selection based on data shape.""" + + def test_rgb_image_codec_selection(self): + """Test that RGB images get video codecs when compatible.""" + config = CodecConfig() + + # RGB image with even dimensions should get a video codec + rgb_even = FeatureType(dtype="uint8", shape=(128, 128, 3)) + codec = config.get_codec_for_feature(rgb_even) + assert codec != "rawvideo", f"RGB image with even dimensions should get video codec, got {codec}" + assert codec in ["libx264", "libx265", "libaom-av1", "ffv1"], f"Got unexpected codec: {codec}" + + def test_non_rgb_shapes_use_rawvideo(self): + """Test that non-RGB shapes always use rawvideo.""" + config = CodecConfig() + + test_cases = [ + ((128, 128), "Grayscale image"), + ((10,), "1D vector"), + ((5, 10), "2D matrix"), + ((128, 128, 1), "Single channel image"), + ((128, 128, 4), "RGBA image"), + ((20, 30, 5), "Multi-channel data"), + ] + + for shape, description in test_cases: + feature_type = FeatureType(dtype="float32", shape=shape) + codec = config.get_codec_for_feature(feature_type) + assert codec == "rawvideo", f"{description} should use rawvideo, got {codec}" + + def test_user_specified_codec_validation(self): + """Test user-specified codec validation for RGB images.""" + # Valid user-specified codec for compatible RGB image + config = CodecConfig(codec="libx264") + rgb_even = FeatureType(dtype="uint8", shape=(128, 128, 3)) + codec = config.get_codec_for_feature(rgb_even) + assert codec == "libx264", f"Compatible RGB should use user-specified codec, got {codec}" + + # Invalid user-specified codec for incompatible RGB image + config = CodecConfig(codec="libx264") + rgb_odd = FeatureType(dtype="uint8", shape=(127, 129, 3)) + codec = config.get_codec_for_feature(rgb_odd) + assert codec == "rawvideo", f"Incompatible RGB should fall back to rawvideo, got {codec}" + + +class TestCodecCompatibilityValidation: + """Test codec compatibility validation methods.""" + + def test_is_valid_image_shape(self): + """Test the is_valid_image_shape method.""" + test_cases = [ + # (shape, codec, expected_result, description) + ((128, 128, 3), "libx264", True, "Even dimensions should work"), + ((127, 129, 3), "libx264", False, "Odd dimensions should fail for H.264"), + ((1920, 1080, 3), "libx264", True, "Large even dimensions should work"), + ((2, 2, 3), "libx264", True, "Very small even dimensions might work"), + ((128, 128), "libx264", False, "Non-RGB should not be valid for video codec"), + ((10,), "libx264", False, "1D data should not be valid"), + ] + + for shape, codec, expected, description in test_cases: + result = CodecConfig.is_valid_image_shape(shape, codec) + assert result == expected, f"{description}: shape {shape} with {codec} expected {expected}, got {result}" + + def test_is_codec_config_supported(self): + """Test PyAV codec configuration support.""" + # These should work for most systems + assert CodecConfig.is_codec_config_supported(128, 128, "yuv420p", "libx264") + + # Very large dimensions might not work + large_result = CodecConfig.is_codec_config_supported(10000, 10000, "yuv420p", "libx264") + # Don't assert this as it depends on system capabilities + print(f"Large dimensions test result: {large_result}") + + +class TestRoundtripData: + """Test roundtrip encoding/decoding for various data shapes.""" + + def test_different_shapes_and_types(self): + """Test that different data shapes and types can be handled.""" + config = CodecConfig() + + test_cases = [ + # (shape, dtype, expected_codec_type) + ((128, 128, 3), "uint8", "video"), # RGB image + ((100, 200, 3), "uint8", "video"), # Different RGB size + ((128, 128), "uint8", "rawvideo"), # Grayscale + ((10,), "float32", "rawvideo"), # Vector + ((5, 10), "float64", "rawvideo"), # Matrix + ((128, 128, 1), "uint8", "rawvideo"), # Single channel + ((128, 128, 4), "uint8", "rawvideo"), # RGBA + ] + + for shape, dtype, expected_type in test_cases: + feature_type = FeatureType(dtype=dtype, shape=shape) + codec = config.get_codec_for_feature(feature_type) + + if expected_type == "video": + assert codec != "rawvideo", f"Shape {shape} should get video codec, got {codec}" + else: + assert codec == "rawvideo", f"Shape {shape} should get rawvideo, got {codec}" + + def test_mixed_rgb_and_non_rgb_in_trajectory(self): + """Test handling mixed RGB and non-RGB data types.""" + config = CodecConfig() + + # Simulate mixed data in a trajectory + features = { + "camera/rgb": FeatureType(dtype="uint8", shape=(128, 128, 3)), # RGB + "camera/depth": FeatureType(dtype="float32", shape=(128, 128)), # Depth + "robot/joint_pos": FeatureType(dtype="float32", shape=(7,)), # Vector + "camera/mask": FeatureType(dtype="uint8", shape=(128, 128, 1)), # Mask + } + + codecs = {} + for name, feature_type in features.items(): + codecs[name] = config.get_codec_for_feature(feature_type) + + # Only RGB should get video codec + assert codecs["camera/rgb"] != "rawvideo", "RGB should get video codec" + assert codecs["camera/depth"] == "rawvideo", "Depth should get rawvideo" + assert codecs["robot/joint_pos"] == "rawvideo", "Joint positions should get rawvideo" + assert codecs["camera/mask"] == "rawvideo", "Mask should get rawvideo" + + +class TestPixelFormatSelection: + """Test pixel format selection logic.""" + + def test_rgb_pixel_format_selection(self): + """Test pixel format selection for RGB data.""" + config = CodecConfig() + + rgb_type = FeatureType(dtype="uint8", shape=(128, 128, 3)) + + # Test different codecs + yuv_codecs = ["libx264", "libx265", "libaom-av1", "ffv1"] + for codec in yuv_codecs: + result = config.get_pixel_format(codec, rgb_type) + assert result == "yuv420p", f"RGB data with {codec} should get yuv420p, got {result}" + + def test_non_rgb_pixel_format_selection(self): + """Test pixel format selection for non-RGB data.""" + config = CodecConfig() + + # Non-RGB data should not get RGB pixel formats + grayscale_type = FeatureType(dtype="uint8", shape=(128, 128)) + vector_type = FeatureType(dtype="float32", shape=(10,)) + + # These should return None (no pixel format for non-RGB) + for data_type in [grayscale_type, vector_type]: + for codec in ["libx264", "libx265", "libaom-av1", "ffv1"]: + result = config.get_pixel_format(codec, data_type) + # Should not return RGB-specific formats + assert result is None, f"Non-RGB data should not get pixel format, got {result}" + + def test_rawvideo_pixel_format(self): + """Test that rawvideo returns None for pixel format.""" + config = CodecConfig() + + rgb_type = FeatureType(dtype="uint8", shape=(128, 128, 3)) + result = config.get_pixel_format("rawvideo", rgb_type) + assert result is None, f"rawvideo should return None for pixel format, got {result}" + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for tests.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield tmpdir \ No newline at end of file From 9fe379791e9ba607e5400178313754bc5f28d380 Mon Sep 17 00:00:00 2001 From: Eric Kaiyuan Chen Date: Thu, 5 Jun 2025 22:49:46 -0700 Subject: [PATCH 2/2] format --- robodm/dataset.py | 152 ++++++++++++++++++-------------- robodm/loader/vla.py | 5 +- robodm/trajectory.py | 99 +++++++++++++-------- tests/test_openx_trajectory.py | 12 +-- tests/test_shape_codec_logic.py | 130 +++++++++++++++++---------- 5 files changed, 238 insertions(+), 160 deletions(-) diff --git a/robodm/dataset.py b/robodm/dataset.py index 4f5f03a..4c297ae 100644 --- a/robodm/dataset.py +++ b/robodm/dataset.py @@ -38,14 +38,16 @@ class VLADataset: 4. Efficient data management for large datasets """ - def __init__(self, - path: Text, - mode: Union[str, LoadingMode] = LoadingMode.TRAJECTORY, - split: str = "all", - return_type: str = "numpy", - config: Optional[DatasetConfig] = None, - slice_config: Optional[SliceConfig] = None, - **kwargs): + def __init__( + self, + path: Text, + mode: Union[str, LoadingMode] = LoadingMode.TRAJECTORY, + split: str = "all", + return_type: str = "numpy", + config: Optional[DatasetConfig] = None, + slice_config: Optional[SliceConfig] = None, + **kwargs, + ): """ Initialize VLA dataset. @@ -85,37 +87,44 @@ def __init__(self, shuffle=self.config.shuffle, num_parallel_reads=self.config.num_parallel_reads, slice_config=slice_config, - **kwargs) + **kwargs, + ) # Cache for schema and stats self._schema = None self._stats = None @classmethod - def create_trajectory_dataset(cls, - path: Text, - split: str = "all", - return_type: str = "numpy", - config: Optional[DatasetConfig] = None, - **kwargs) -> "VLADataset": + def create_trajectory_dataset( + cls, + path: Text, + split: str = "all", + return_type: str = "numpy", + config: Optional[DatasetConfig] = None, + **kwargs, + ) -> "VLADataset": """Create a dataset for loading complete trajectories.""" - return cls(path=path, - mode=LoadingMode.TRAJECTORY, - return_type=return_type, - config=config, - **kwargs) + return cls( + path=path, + mode=LoadingMode.TRAJECTORY, + return_type=return_type, + config=config, + **kwargs, + ) @classmethod - def create_slice_dataset(cls, - path: Text, - slice_length: int = 100, - return_type: str = "numpy", - config: Optional[DatasetConfig] = None, - min_slice_length: Optional[int] = None, - stride: int = 1, - random_start: bool = True, - overlap_ratio: float = 0.0, - **kwargs) -> "VLADataset": + def create_slice_dataset( + cls, + path: Text, + slice_length: int = 100, + return_type: str = "numpy", + config: Optional[DatasetConfig] = None, + min_slice_length: Optional[int] = None, + stride: int = 1, + random_start: bool = True, + overlap_ratio: float = 0.0, + **kwargs, + ) -> "VLADataset": """Create a dataset for loading trajectory slices.""" slice_config = SliceConfig( slice_length=slice_length, @@ -125,12 +134,14 @@ def create_slice_dataset(cls, overlap_ratio=overlap_ratio, ) - return cls(path=path, - mode=LoadingMode.SLICE, - return_type=return_type, - config=config, - slice_config=slice_config, - **kwargs) + return cls( + path=path, + mode=LoadingMode.SLICE, + return_type=return_type, + config=config, + slice_config=slice_config, + **kwargs, + ) def get_ray_dataset(self) -> rd.Dataset: """Get the underlying Ray dataset.""" @@ -245,7 +256,7 @@ def get_stats(self) -> Dict[str, Any]: "total_items": self.count(), "sample_keys": - list(sample.keys()) if isinstance(sample, dict) else [], + (list(sample.keys()) if isinstance(sample, dict) else []), } # Add mode-specific stats @@ -260,8 +271,9 @@ def get_stats(self) -> Dict[str, Any]: first_key = next(iter(sample.keys())) if sample else None if first_key and hasattr(sample[first_key], "__len__"): self._stats["slice_length"] = len(sample[first_key]) - self._stats[ - "slice_start"] = 0 # Cannot determine from direct data + self._stats["slice_start"] = ( + 0 # Cannot determine from direct data + ) self._stats["slice_end"] = len(sample[first_key]) else: self._stats = {"mode": self.mode.value, "total_items": 0} @@ -313,13 +325,15 @@ def get_next_trajectory(self): # Utility functions for common dataset operations -def load_trajectory_dataset(path: Text, - split: str = "all", - return_type: str = "numpy", - batch_size: int = 1, - shuffle: bool = False, - num_parallel_reads: int = 4, - **kwargs) -> VLADataset: +def load_trajectory_dataset( + path: Text, + split: str = "all", + return_type: str = "numpy", + batch_size: int = 1, + shuffle: bool = False, + num_parallel_reads: int = 4, + **kwargs, +) -> VLADataset: """Load a dataset for complete trajectories.""" config = DatasetConfig(batch_size=batch_size, shuffle=shuffle, @@ -330,31 +344,35 @@ def load_trajectory_dataset(path: Text, **kwargs) -def load_slice_dataset(path: Text, - slice_length: int = 100, - split: str = "all", - return_type: str = "numpy", - batch_size: int = 1, - shuffle: bool = False, - num_parallel_reads: int = 4, - min_slice_length: Optional[int] = None, - stride: int = 1, - random_start: bool = True, - overlap_ratio: float = 0.0, - **kwargs) -> VLADataset: +def load_slice_dataset( + path: Text, + slice_length: int = 100, + split: str = "all", + return_type: str = "numpy", + batch_size: int = 1, + shuffle: bool = False, + num_parallel_reads: int = 4, + min_slice_length: Optional[int] = None, + stride: int = 1, + random_start: bool = True, + overlap_ratio: float = 0.0, + **kwargs, +) -> VLADataset: """Load a dataset for trajectory slices.""" config = DatasetConfig(batch_size=batch_size, shuffle=shuffle, num_parallel_reads=num_parallel_reads) - return VLADataset.create_slice_dataset(path=path, - slice_length=slice_length, - return_type=return_type, - config=config, - min_slice_length=min_slice_length, - stride=stride, - random_start=random_start, - overlap_ratio=overlap_ratio, - **kwargs) + return VLADataset.create_slice_dataset( + path=path, + slice_length=slice_length, + return_type=return_type, + config=config, + min_slice_length=min_slice_length, + stride=stride, + random_start=random_start, + overlap_ratio=overlap_ratio, + **kwargs, + ) def split_dataset( diff --git a/robodm/loader/vla.py b/robodm/loader/vla.py index d978101..fc456f1 100644 --- a/robodm/loader/vla.py +++ b/robodm/loader/vla.py @@ -34,8 +34,9 @@ class SliceConfig: """Configuration for slice loading mode.""" slice_length: int = 100 # Number of timesteps per slice - min_slice_length: Optional[ - int] = None # Minimum slice length (defaults to slice_length) + min_slice_length: Optional[int] = ( + None # Minimum slice length (defaults to slice_length) + ) stride: int = 1 # Stride between consecutive timesteps in slice random_start: bool = True # Whether to randomly sample start position overlap_ratio: float = 0.0 # Overlap ratio between consecutive slices (0.0-1.0) diff --git a/robodm/trajectory.py b/robodm/trajectory.py index 8365f3b..e4076d2 100644 --- a/robodm/trajectory.py +++ b/robodm/trajectory.py @@ -321,6 +321,7 @@ def get_supported_pixel_formats(codec_name: str) -> List[str]: """Get list of supported pixel formats for a codec.""" try: import av + codec = av.codec.Codec(codec_name, "w") if codec.video_formats: return [vf.name for vf in codec.video_formats] @@ -329,12 +330,16 @@ def get_supported_pixel_formats(codec_name: str) -> List[str]: return [] @staticmethod - def is_codec_config_supported(width: int, height: int, pix_fmt: str = "yuv420p", codec_name: str = "libx264") -> bool: + def is_codec_config_supported(width: int, + height: int, + pix_fmt: str = "yuv420p", + codec_name: str = "libx264") -> bool: """Check if a specific width/height/pixel format combination is supported by codec.""" try: - import av from fractions import Fraction - + + import av + cc = av.codec.CodecContext.create(codec_name, "w") cc.width = width cc.height = height @@ -347,18 +352,19 @@ def is_codec_config_supported(width: int, height: int, pix_fmt: str = "yuv420p", return False @staticmethod - def is_valid_image_shape(shape: Tuple[int, ...], codec_name: str = "libx264") -> bool: + def is_valid_image_shape(shape: Tuple[int, ...], + codec_name: str = "libx264") -> bool: """Check if a shape can be treated as an RGB image for the given codec.""" # Only accept RGB shapes (H, W, 3) if len(shape) != 3 or shape[2] != 3: return False - + height, width = shape[0], shape[1] - + # Check minimum reasonable image size if height < 1 or width < 1: return False - + # Check codec-specific constraints if codec_name in ["libx264", "libx265"]: # H.264/H.265 require even dimensions @@ -368,9 +374,10 @@ def is_valid_image_shape(shape: Tuple[int, ...], codec_name: str = "libx264") -> # AV1 also typically requires even dimensions for yuv420p if height % 2 != 0 or width % 2 != 0: return False - + # Test if the codec actually supports this resolution - return CodecConfig.is_codec_config_supported(width, height, "yuv420p", codec_name) + return CodecConfig.is_codec_config_supported(width, height, "yuv420p", + codec_name) # Default codec configurations CODEC_CONFIGS = { @@ -428,36 +435,44 @@ def get_codec_for_feature(self, feature_type: FeatureType) -> str: """Determine the appropriate codec for a given feature type.""" data_shape = feature_type.shape - + # Only use video codecs for RGB images (H, W, 3) - if (data_shape is not None and len(data_shape) == 3 and data_shape[2] == 3): + if data_shape is not None and len( + data_shape) == 3 and data_shape[2] == 3: height, width = data_shape[0], data_shape[1] - + # If user specified a codec other than auto, try to use it for RGB images if self.codec != "auto": if self.is_valid_image_shape(data_shape, self.codec): - logger.debug(f"Using user-specified codec {self.codec} for RGB shape {data_shape}") + logger.debug( + f"Using user-specified codec {self.codec} for RGB shape {data_shape}" + ) return self.codec else: - logger.warning(f"User-specified codec {self.codec} doesn't support shape {data_shape}, falling back to rawvideo") + logger.warning( + f"User-specified codec {self.codec} doesn't support shape {data_shape}, falling back to rawvideo" + ) return "rawvideo" - + # Auto-selection for RGB images only codec_preferences = ["libaom-av1", "ffv1", "libx264", "libx265"] - + for codec in codec_preferences: if self.is_valid_image_shape(data_shape, codec): - logger.debug(f"Selected codec {codec} for RGB shape {data_shape}") + logger.debug( + f"Selected codec {codec} for RGB shape {data_shape}") return codec - + # If no video codec works for this RGB image, fall back to rawvideo - logger.warning(f"No video codec supports RGB shape {data_shape}, falling back to rawvideo") - + logger.warning( + f"No video codec supports RGB shape {data_shape}, falling back to rawvideo" + ) + else: # Non-RGB data (grayscale, depth, vectors, etc.) always use rawvideo if data_shape is not None: logger.debug(f"Using rawvideo for non-RGB shape {data_shape}") - + return "rawvideo" def get_pixel_format(self, codec: str, @@ -552,16 +567,16 @@ def __init__( self.feature_name_to_stream: Dict[str, Any] = {} # feature_name: stream - self.feature_name_to_feature_type: Dict[str, FeatureType] = { - } # feature_name: feature_type + self.feature_name_to_feature_type: Dict[str, FeatureType] = ( + {}) # feature_name: feature_type self.trajectory_data = None # trajectory_data self.start_time = self._time() self.mode = mode self.stream_id_to_info: Dict[int, StreamInfo] = {} # stream_id: StreamInfo self.is_closed = False - self.pending_write_tasks: List[Any] = [ - ] # List to keep track of pending write tasks + self.pending_write_tasks: List[Any] = ( + []) # List to keep track of pending write tasks self.container_file: Optional[Any] = None # av.OutputContainer or None # check if the path exists @@ -1035,9 +1050,11 @@ def want(idx: int) -> bool: arr = frame.to_ndarray(format="rgb24") else: # This shouldn't happen with our new logic, but handle gracefully - logger.warning(f"Non-RGB data {fname} with shape {ft.shape} using video codec - this may cause issues") + logger.warning( + f"Non-RGB data {fname} with shape {ft.shape} using video codec - this may cause issues" + ) arr = frame.to_ndarray(format="rgb24") - + if ft.shape: arr = arr.reshape(ft.shape) cache[fname].append(arr) @@ -1082,9 +1099,11 @@ def want(idx: int) -> bool: arr = frame.to_ndarray(format="rgb24") else: # This shouldn't happen with our new logic, but handle gracefully - logger.warning(f"Non-RGB data {fname} with shape {ft.shape} using video codec - this may cause issues") + logger.warning( + f"Non-RGB data {fname} with shape {ft.shape} using video codec - this may cause issues" + ) arr = frame.to_ndarray(format="rgb24") - + if ft.shape: arr = arr.reshape(ft.shape) cache[fname].append(arr) @@ -1176,8 +1195,6 @@ def add( if type(data) == dict: raise ValueError("Use add_by_dict for dictionary") - - feature_type = FeatureType.from_data(data) # encoding = self._get_encoding_of_feature(data, None) self.feature_name_to_feature_type[feature] = feature_type @@ -1430,7 +1447,9 @@ def _load_from_container(self): format="rgb24") # type: ignore[attr-defined] else: # This shouldn't happen with our new logic, but handle gracefully - logger.warning(f"Non-RGB data {feature_name} with shape {shape} using video codec") + logger.warning( + f"Non-RGB data {feature_name} with shape {shape} using video codec" + ) if shape is not None: data = frame.to_ndarray( # type: ignore[attr-defined] format="rgb24").reshape(shape) @@ -1520,8 +1539,8 @@ def _transcode_pickled_images(self, for key, value in stream.metadata.items(): stream_in_updated_container.metadata[key] = value - d_original_stream_id_to_new_container_stream[ - stream.index] = stream_in_updated_container + d_original_stream_id_to_new_container_stream[stream.index] = ( + stream_in_updated_container) # Transcode pickled images and add them to the new container packets_muxed = 0 @@ -1707,8 +1726,8 @@ def _on_new_stream(self, new_feature, new_encoding, new_feature_type): # new_stream.options = stream.options for key, value in stream.metadata.items(): stream_in_updated_container.metadata[key] = value - d_original_stream_id_to_new_container_stream[ - stream.index] = stream_in_updated_container + d_original_stream_id_to_new_container_stream[stream.index] = ( + stream_in_updated_container) # Add new feature stream new_stream = self._add_stream_to_container(new_container, @@ -1778,13 +1797,14 @@ def _create_frame(self, image_array, stream): # Assume float32 values are in [0, 1] range, scale to [0, 255] image_array = np.clip(image_array * 255, 0, 255).astype(np.uint8) elif image_array.dtype != np.uint8: - # Convert other dtypes to uint8 + # Convert other dtypes to uint8 if np.issubdtype(image_array.dtype, np.integer): # For integer types, clamp to 0-255 range image_array = np.clip(image_array, 0, 255).astype(np.uint8) else: # For other types, normalize and convert - image_array = np.clip(image_array * 255, 0, 255).astype(np.uint8) + image_array = np.clip(image_array * 255, 0, + 255).astype(np.uint8) # Only handle RGB images (HxWx3) - no grayscale conversion if len(image_array.shape) == 3 and image_array.shape[2] == 3: @@ -1793,7 +1813,8 @@ def _create_frame(self, image_array, stream): else: raise ValueError( f"Video codecs only support RGB images with shape (H, W, 3). " - f"Got shape {image_array.shape}. Use rawvideo encoding for other formats.") + f"Got shape {image_array.shape}. Use rawvideo encoding for other formats." + ) # Create RGB frame if encoding in ["libaom-av1", "ffv1", "libx264", "libx265"]: diff --git a/tests/test_openx_trajectory.py b/tests/test_openx_trajectory.py index fcab93d..fe1726f 100644 --- a/tests/test_openx_trajectory.py +++ b/tests/test_openx_trajectory.py @@ -377,8 +377,8 @@ def test_openx_codec_availability_report(self, temp_dir, mock_openx_data): print("=" * 60) # Ensure at least one codec works with OpenX data - assert (len(available_codecs) > - 0), "No codecs are available for Open X-Embodiment data!" + assert (len(available_codecs) + > 0), "No codecs are available for Open X-Embodiment data!" class TestRLDSLoaderIntegration: @@ -1702,8 +1702,8 @@ def test_openx_format_comparison(self, temp_dir, openx_test_data, # Ensure file sizes are reasonable (not empty, not too large) for format_name, metrics in successful_formats.items(): - assert (metrics["file_size_mb"] > - 0), f"{format_name} file should not be empty" + assert (metrics["file_size_mb"] + > 0), f"{format_name} file should not be empty" assert (metrics["file_size_mb"] < original_size_mb * 10), f"{format_name} file suspiciously large" @@ -2760,8 +2760,8 @@ def test_openx_loader_scalability(self, temp_dir): f" 💾 {fmt2} is {1/size_ratio:.2f}x more compact than {fmt1}" ) - assert (len(scalability_results) > - 0), "At least one scalability test should succeed" + assert (len(scalability_results) + > 0), "At least one scalability test should succeed" # Test scalability characteristics for format_name in formats: diff --git a/tests/test_shape_codec_logic.py b/tests/test_shape_codec_logic.py index fa10c18..1a220f7 100644 --- a/tests/test_shape_codec_logic.py +++ b/tests/test_shape_codec_logic.py @@ -16,30 +16,38 @@ class TestShapeBasedCodecSelection: def test_rgb_image_codec_selection(self): """Test that RGB images get video codecs when compatible.""" config = CodecConfig() - + # RGB image with even dimensions should get a video codec rgb_even = FeatureType(dtype="uint8", shape=(128, 128, 3)) codec = config.get_codec_for_feature(rgb_even) - assert codec != "rawvideo", f"RGB image with even dimensions should get video codec, got {codec}" - assert codec in ["libx264", "libx265", "libaom-av1", "ffv1"], f"Got unexpected codec: {codec}" + assert ( + codec != "rawvideo" + ), f"RGB image with even dimensions should get video codec, got {codec}" + assert codec in [ + "libx264", + "libx265", + "libaom-av1", + "ffv1", + ], f"Got unexpected codec: {codec}" def test_non_rgb_shapes_use_rawvideo(self): """Test that non-RGB shapes always use rawvideo.""" config = CodecConfig() - + test_cases = [ ((128, 128), "Grayscale image"), - ((10,), "1D vector"), + ((10, ), "1D vector"), ((5, 10), "2D matrix"), ((128, 128, 1), "Single channel image"), ((128, 128, 4), "RGBA image"), ((20, 30, 5), "Multi-channel data"), ] - + for shape, description in test_cases: feature_type = FeatureType(dtype="float32", shape=shape) codec = config.get_codec_for_feature(feature_type) - assert codec == "rawvideo", f"{description} should use rawvideo, got {codec}" + assert (codec == "rawvideo" + ), f"{description} should use rawvideo, got {codec}" def test_user_specified_codec_validation(self): """Test user-specified codec validation for RGB images.""" @@ -47,13 +55,17 @@ def test_user_specified_codec_validation(self): config = CodecConfig(codec="libx264") rgb_even = FeatureType(dtype="uint8", shape=(128, 128, 3)) codec = config.get_codec_for_feature(rgb_even) - assert codec == "libx264", f"Compatible RGB should use user-specified codec, got {codec}" - - # Invalid user-specified codec for incompatible RGB image + assert ( + codec == "libx264" + ), f"Compatible RGB should use user-specified codec, got {codec}" + + # Invalid user-specified codec for incompatible RGB image config = CodecConfig(codec="libx264") rgb_odd = FeatureType(dtype="uint8", shape=(127, 129, 3)) codec = config.get_codec_for_feature(rgb_odd) - assert codec == "rawvideo", f"Incompatible RGB should fall back to rawvideo, got {codec}" + assert ( + codec == "rawvideo" + ), f"Incompatible RGB should fall back to rawvideo, got {codec}" class TestCodecCompatibilityValidation: @@ -64,24 +76,36 @@ def test_is_valid_image_shape(self): test_cases = [ # (shape, codec, expected_result, description) ((128, 128, 3), "libx264", True, "Even dimensions should work"), - ((127, 129, 3), "libx264", False, "Odd dimensions should fail for H.264"), - ((1920, 1080, 3), "libx264", True, "Large even dimensions should work"), - ((2, 2, 3), "libx264", True, "Very small even dimensions might work"), - ((128, 128), "libx264", False, "Non-RGB should not be valid for video codec"), - ((10,), "libx264", False, "1D data should not be valid"), + ((127, 129, 3), "libx264", False, + "Odd dimensions should fail for H.264"), + ((1920, 1080, 3), "libx264", True, + "Large even dimensions should work"), + ((2, 2, 3), "libx264", True, + "Very small even dimensions might work"), + ( + (128, 128), + "libx264", + False, + "Non-RGB should not be valid for video codec", + ), + ((10, ), "libx264", False, "1D data should not be valid"), ] - + for shape, codec, expected, description in test_cases: result = CodecConfig.is_valid_image_shape(shape, codec) - assert result == expected, f"{description}: shape {shape} with {codec} expected {expected}, got {result}" + assert ( + result == expected + ), f"{description}: shape {shape} with {codec} expected {expected}, got {result}" def test_is_codec_config_supported(self): """Test PyAV codec configuration support.""" # These should work for most systems - assert CodecConfig.is_codec_config_supported(128, 128, "yuv420p", "libx264") - + assert CodecConfig.is_codec_config_supported(128, 128, "yuv420p", + "libx264") + # Very large dimensions might not work - large_result = CodecConfig.is_codec_config_supported(10000, 10000, "yuv420p", "libx264") + large_result = CodecConfig.is_codec_config_supported( + 10000, 10000, "yuv420p", "libx264") # Don't assert this as it depends on system capabilities print(f"Large dimensions test result: {large_result}") @@ -92,47 +116,55 @@ class TestRoundtripData: def test_different_shapes_and_types(self): """Test that different data shapes and types can be handled.""" config = CodecConfig() - + test_cases = [ # (shape, dtype, expected_codec_type) ((128, 128, 3), "uint8", "video"), # RGB image ((100, 200, 3), "uint8", "video"), # Different RGB size ((128, 128), "uint8", "rawvideo"), # Grayscale - ((10,), "float32", "rawvideo"), # Vector + ((10, ), "float32", "rawvideo"), # Vector ((5, 10), "float64", "rawvideo"), # Matrix ((128, 128, 1), "uint8", "rawvideo"), # Single channel ((128, 128, 4), "uint8", "rawvideo"), # RGBA ] - + for shape, dtype, expected_type in test_cases: feature_type = FeatureType(dtype=dtype, shape=shape) codec = config.get_codec_for_feature(feature_type) - + if expected_type == "video": - assert codec != "rawvideo", f"Shape {shape} should get video codec, got {codec}" + assert (codec != "rawvideo" + ), f"Shape {shape} should get video codec, got {codec}" else: - assert codec == "rawvideo", f"Shape {shape} should get rawvideo, got {codec}" + assert (codec == "rawvideo" + ), f"Shape {shape} should get rawvideo, got {codec}" def test_mixed_rgb_and_non_rgb_in_trajectory(self): """Test handling mixed RGB and non-RGB data types.""" config = CodecConfig() - + # Simulate mixed data in a trajectory features = { - "camera/rgb": FeatureType(dtype="uint8", shape=(128, 128, 3)), # RGB - "camera/depth": FeatureType(dtype="float32", shape=(128, 128)), # Depth - "robot/joint_pos": FeatureType(dtype="float32", shape=(7,)), # Vector - "camera/mask": FeatureType(dtype="uint8", shape=(128, 128, 1)), # Mask + "camera/rgb": FeatureType(dtype="uint8", + shape=(128, 128, 3)), # RGB + "camera/depth": FeatureType(dtype="float32", + shape=(128, 128)), # Depth + "robot/joint_pos": FeatureType(dtype="float32", + shape=(7, )), # Vector + "camera/mask": FeatureType(dtype="uint8", + shape=(128, 128, 1)), # Mask } - + codecs = {} for name, feature_type in features.items(): codecs[name] = config.get_codec_for_feature(feature_type) - + # Only RGB should get video codec assert codecs["camera/rgb"] != "rawvideo", "RGB should get video codec" - assert codecs["camera/depth"] == "rawvideo", "Depth should get rawvideo" - assert codecs["robot/joint_pos"] == "rawvideo", "Joint positions should get rawvideo" + assert codecs[ + "camera/depth"] == "rawvideo", "Depth should get rawvideo" + assert (codecs["robot/joint_pos"] == "rawvideo" + ), "Joint positions should get rawvideo" assert codecs["camera/mask"] == "rawvideo", "Mask should get rawvideo" @@ -142,41 +174,47 @@ class TestPixelFormatSelection: def test_rgb_pixel_format_selection(self): """Test pixel format selection for RGB data.""" config = CodecConfig() - + rgb_type = FeatureType(dtype="uint8", shape=(128, 128, 3)) - + # Test different codecs yuv_codecs = ["libx264", "libx265", "libaom-av1", "ffv1"] for codec in yuv_codecs: result = config.get_pixel_format(codec, rgb_type) - assert result == "yuv420p", f"RGB data with {codec} should get yuv420p, got {result}" + assert ( + result == "yuv420p" + ), f"RGB data with {codec} should get yuv420p, got {result}" def test_non_rgb_pixel_format_selection(self): """Test pixel format selection for non-RGB data.""" config = CodecConfig() - + # Non-RGB data should not get RGB pixel formats grayscale_type = FeatureType(dtype="uint8", shape=(128, 128)) - vector_type = FeatureType(dtype="float32", shape=(10,)) - + vector_type = FeatureType(dtype="float32", shape=(10, )) + # These should return None (no pixel format for non-RGB) for data_type in [grayscale_type, vector_type]: for codec in ["libx264", "libx265", "libaom-av1", "ffv1"]: result = config.get_pixel_format(codec, data_type) # Should not return RGB-specific formats - assert result is None, f"Non-RGB data should not get pixel format, got {result}" + assert ( + result is None + ), f"Non-RGB data should not get pixel format, got {result}" def test_rawvideo_pixel_format(self): """Test that rawvideo returns None for pixel format.""" config = CodecConfig() - + rgb_type = FeatureType(dtype="uint8", shape=(128, 128, 3)) result = config.get_pixel_format("rawvideo", rgb_type) - assert result is None, f"rawvideo should return None for pixel format, got {result}" + assert ( + result is None + ), f"rawvideo should return None for pixel format, got {result}" @pytest.fixture def temp_dir(): """Create a temporary directory for tests.""" with tempfile.TemporaryDirectory() as tmpdir: - yield tmpdir \ No newline at end of file + yield tmpdir