Skip to content

Commit 8bdbdb7

Browse files
committed
Add direct encoding option to Trajectory class and optimize stream creation
- Added a `force_direct_encoding` parameter to the `add` and `add_by_dict` methods to enable direct codec encoding.
- Updated the stream creation logic to use direct encoding when requested and fall back to rawvideo otherwise.
- Enhanced batch data processing in the `from_list_of_dicts` and `from_dict_of_lists` methods.
- Refactored `PyAVBackend` to support direct encoding and optimized its stream handling.
- Removed the deprecated test file for the OpenX trajectory functionality.
1 parent d90695a commit 8bdbdb7

5 files changed

Lines changed: 388 additions & 3396 deletions

File tree

robodm/backend/codec_config.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,17 @@ def is_valid_image_shape(shape: Tuple[int, ...],
6464
# AV1 also typically requires even dimensions for yuv420p
6565
if height % 2 != 0 or width % 2 != 0:
6666
return False
67+
elif codec_name == "ffv1":
68+
# FFV1 can handle odd dimensions but requires minimal size
69+
if height < 2 or width < 2:
70+
return False
6771

6872
# Test if the codec actually supports this resolution
69-
return CodecConfig.is_codec_config_supported(width, height, "yuv420p",
70-
codec_name)
73+
# For FFV1, test with rgb24 instead of yuv420p
74+
if codec_name == "ffv1":
75+
return CodecConfig.is_codec_config_supported(width, height, "rgb24", codec_name)
76+
else:
77+
return CodecConfig.is_codec_config_supported(width, height, "yuv420p", codec_name)
7178

7279
@staticmethod
7380
def is_image_codec(codec_name: str) -> bool:
@@ -370,14 +377,16 @@ def get_pixel_format(self, codec: str, feature_type: FeatureType) -> Optional[st
370377
if codec in self.IMAGE_CODEC_CONFIGS:
371378
base_format = self.IMAGE_CODEC_CONFIGS[codec].get("pixel_format")
372379

373-
# For FFV1, adjust pixel format based on data type
374-
if codec == "ffv1" and feature_type.dtype == "uint8":
380+
# For FFV1, use RGB24 to avoid YUV conversion issues
381+
if codec == "ffv1":
375382
data_shape = feature_type.shape
376383
if data_shape is not None and len(data_shape) == 3:
377384
if data_shape[2] == 3: # RGB
378385
return "rgb24"
379386
elif data_shape[2] == 4: # RGBA
380387
return "rgba"
388+
# Fallback to rgb24 for any other FFV1 case
389+
return "rgb24"
381390

382391
return base_format
383392

robodm/backend/pyav_backend.py

Lines changed: 153 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import numpy as np
2323

2424
from .base import ContainerBackend, StreamMetadata, PacketInfo, StreamConfig
25-
from robodm.feature import FeatureType
25+
from robodm import FeatureType
2626
from robodm.backend.codec_config import CodecConfig
2727
from .codec_manager import CodecManager
2828

@@ -100,15 +100,28 @@ def encode_data_to_packets(
100100
data: Any,
101101
stream_index: int,
102102
timestamp: int,
103-
codec_config: Any
103+
codec_config: Any,
104+
force_direct_encoding: bool = False
104105
) -> List[PacketInfo]:
105-
"""Encode arbitrary data into packets with timestamp handling"""
106+
"""Encode arbitrary data into packets with timestamp handling
107+
108+
Args:
109+
data: Data to encode
110+
stream_index: Target stream index
111+
timestamp: Timestamp in milliseconds
112+
codec_config: Codec configuration
113+
force_direct_encoding: If True, encode directly to target format instead of rawvideo
114+
"""
106115
if stream_index not in self._idx_to_stream:
107116
raise ValueError(f"No stream with index {stream_index}")
108117

109118
stream = self._idx_to_stream[stream_index]
110119
container_encoding = stream.codec_context.codec.name
111120

121+
# If force_direct_encoding is True, bypass rawvideo intermediate step
122+
if force_direct_encoding and container_encoding != "rawvideo":
123+
return self._encode_directly_to_target(data, stream_index, timestamp, codec_config)
124+
112125
# Create codec if it doesn't exist
113126
codec = self.codec_manager.get_codec_for_stream(stream_index)
114127
if codec is None:
@@ -124,6 +137,37 @@ def encode_data_to_packets(
124137
return packets
125138

126139
return []
140+
141+
def _encode_directly_to_target(self, data: Any, stream_index: int, timestamp: int, codec_config: Any) -> List[PacketInfo]:
    """Encode ``data`` with the stream's own codec, skipping the rawvideo step.

    Args:
        data: Payload to encode. Only ``np.ndarray`` values with at least
            two dimensions take the direct video path.
        stream_index: Key into ``self._idx_to_stream``.
        timestamp: Value assigned unchanged to both ``pts`` and ``dts`` of
            the frame. NOTE(review): presumably already expressed in
            ``stream.time_base`` units - confirm against callers.
        codec_config: Accepted for signature parity; not read in this method.

    Returns:
        One ``PacketInfo`` per packet produced by ``stream.encode`` for the
        frame, or whatever ``self._legacy_encode_fallback`` returns when the
        codec or the data does not qualify for the direct path.

    Raises:
        ValueError: If ``stream_index`` is not a registered stream.
    """
    if stream_index not in self._idx_to_stream:
        raise ValueError(f"No stream with index {stream_index}")

    stream = self._idx_to_stream[stream_index]
    codec_name = stream.codec_context.codec.name

    # Codecs for which a video frame is handed straight to the encoder.
    direct_video_codecs = ("ffv1", "libaom-av1", "libx264", "libx265")
    takes_direct_path = (
        codec_name in direct_video_codecs
        and isinstance(data, np.ndarray)
        and data.ndim >= 2
    )
    if not takes_direct_path:
        # Fallback to legacy encoding if direct encoding isn't supported
        return self._legacy_encode_fallback(data, stream_index, timestamp, stream)

    frame = self._create_frame(data, stream)
    frame.time_base = stream.time_base
    frame.pts = timestamp
    frame.dts = timestamp

    return [
        PacketInfo(
            data=bytes(pkt),
            pts=pkt.pts,
            dts=pkt.dts,
            stream_index=stream_index,
            time_base=(stream.time_base.numerator, stream.time_base.denominator),
            # Packets lacking the attribute are reported as non-keyframes.
            is_keyframe=bool(getattr(pkt, "is_keyframe", False)),
        )
        for pkt in stream.encode(frame)  # type: ignore[attr-defined]
    ]
127171

128172
def _get_feature_type_from_stream(self, stream: Any) -> Any:
129173
"""Extract feature type information from stream metadata"""
@@ -612,12 +656,15 @@ def _create_frame(self, image_array, stream):
612656
f"Got shape {image_array.shape}."
613657
)
614658

615-
# Create RGB frame and convert to YUV420p when required.
616-
if encoding in {"libaom-av1", "ffv1", "libx264", "libx265"}:
617-
frame = av.VideoFrame.from_ndarray(image_array, format="rgb24")
618-
frame = frame.reformat(format="yuv420p")
619-
else:
620-
frame = av.VideoFrame.from_ndarray(image_array, format="rgb24")
659+
# Create RGB frame
660+
frame = av.VideoFrame.from_ndarray(image_array, format="rgb24")
661+
662+
# Get the configured pixel format for this stream
663+
configured_pix_fmt = stream.pix_fmt
664+
665+
# Convert to the configured pixel format if different from RGB24
666+
if configured_pix_fmt and configured_pix_fmt != "rgb24":
667+
frame = frame.reformat(format=configured_pix_fmt)
621668

622669
return frame
623670

@@ -769,4 +816,100 @@ def _transcode_raw_internal(self, packet: Any, output_stream: Any, output_contai
769816

770817
except Exception as e:
771818
logger.error(f"Failed to transcode internal codec: {e}")
772-
return False
819+
return False
820+
821+
def create_streams_for_batch_data(
822+
self,
823+
sample_data: Dict[str, Any],
824+
codec_config: Any,
825+
feature_name_separator: str = "/"
826+
) -> Dict[str, int]:
827+
"""Create optimized streams for batch data processing.
828+
829+
Analyzes sample data to determine optimal codec for each feature
830+
and creates streams with target codec directly.
831+
832+
Args:
833+
sample_data: Sample data dict to analyze feature types
834+
codec_config: Codec configuration
835+
feature_name_separator: Separator for nested feature names
836+
837+
Returns:
838+
Dict mapping feature names to stream indices
839+
"""
840+
if self.container is None:
841+
raise RuntimeError("Container not opened")
842+
843+
from robodm.utils.flatten import _flatten_dict
844+
from robodm import FeatureType
845+
846+
# Flatten the sample data
847+
flattened_data = _flatten_dict(sample_data, sep=feature_name_separator)
848+
849+
feature_to_stream_idx = {}
850+
851+
for feature_name, sample_value in flattened_data.items():
852+
# Determine feature type from sample
853+
feature_type = FeatureType.from_data(sample_value)
854+
855+
# Determine optimal codec for this feature
856+
target_codec = codec_config.get_codec_for_feature(feature_type, feature_name)
857+
container_codec = codec_config.get_container_codec(target_codec)
858+
859+
# Create stream with target codec directly
860+
stream = self.add_stream_for_feature(
861+
feature_name=feature_name,
862+
feature_type=feature_type,
863+
codec_config=codec_config,
864+
encoding=container_codec
865+
)
866+
867+
feature_to_stream_idx[feature_name] = stream.index
868+
869+
logger.debug(f"Created stream for '{feature_name}' with codec '{container_codec}' (target: '{target_codec}')")
870+
871+
return feature_to_stream_idx
872+
873+
def encode_batch_data_directly(
874+
self,
875+
data_batch: List[Dict[str, Any]],
876+
feature_to_stream_idx: Dict[str, int],
877+
codec_config: Any,
878+
feature_name_separator: str = "/",
879+
fps: int = 10
880+
) -> None:
881+
"""Encode a batch of data directly to target codecs without intermediate transcoding.
882+
883+
Args:
884+
data_batch: List of data dictionaries
885+
feature_to_stream_idx: Mapping of feature names to stream indices
886+
codec_config: Codec configuration
887+
feature_name_separator: Separator for nested feature names
888+
fps: Frames per second for timestamp calculation
889+
"""
890+
from robodm.utils.flatten import _flatten_dict
891+
892+
time_interval_ms = 1000 / fps
893+
current_timestamp = 0
894+
895+
for step_data in data_batch:
896+
flattened_data = _flatten_dict(step_data, sep=feature_name_separator)
897+
898+
for feature_name, value in flattened_data.items():
899+
if feature_name in feature_to_stream_idx:
900+
stream_idx = feature_to_stream_idx[feature_name]
901+
902+
# Encode directly to target format
903+
packet_infos = self.encode_data_to_packets(
904+
data=value,
905+
stream_index=stream_idx,
906+
timestamp=int(current_timestamp),
907+
codec_config=codec_config,
908+
force_direct_encoding=True
909+
)
910+
911+
# Mux packets immediately
912+
for packet_info in packet_infos:
913+
self.mux_packet_info(packet_info)
914+
915+
current_timestamp += time_interval_ms

0 commit comments

Comments
 (0)