Skip to content

Commit be19191

Browse files
committed
batch fps
1 parent 8bdbdb7 commit be19191

3 files changed

Lines changed: 68 additions & 16 deletions

File tree

robodm/backend/pyav_backend.py

Lines changed: 60 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
import pickle
1717
import logging
1818
from fractions import Fraction
19-
from typing import Any, Dict, List, Tuple, Optional
19+
from typing import Any, Dict, List, Tuple, Optional, Union
2020

2121
import av
2222
import numpy as np
@@ -822,17 +822,20 @@ def create_streams_for_batch_data(
822822
self,
823823
sample_data: Dict[str, Any],
824824
codec_config: Any,
825-
feature_name_separator: str = "/"
825+
feature_name_separator: str = "/",
826+
visualization_feature: Optional[str] = None
826827
) -> Dict[str, int]:
827828
"""Create optimized streams for batch data processing.
828829
829830
Analyzes sample data to determine optimal codec for each feature
830-
and creates streams with target codec directly.
831+
and creates streams with target codec directly. Respects visualization_feature
832+
ordering to prioritize visualization streams first.
831833
832834
Args:
833835
sample_data: Sample data dict to analyze feature types
834836
codec_config: Codec configuration
835837
feature_name_separator: Separator for nested feature names
838+
visualization_feature: Optional feature name to prioritize as first stream for visualization
836839
837840
Returns:
838841
Dict mapping feature names to stream indices
@@ -846,9 +849,30 @@ def create_streams_for_batch_data(
846849
# Flatten the sample data
847850
flattened_data = _flatten_dict(sample_data, sep=feature_name_separator)
848851

852+
# Sort features to prioritize visualization feature
853+
def get_feature_priority(item):
854+
feature_name, sample_value = item
855+
856+
# Highest priority: specified visualization_feature
857+
if visualization_feature and feature_name == visualization_feature:
858+
return (0, feature_name)
859+
860+
# Second priority: features that will become video-encoded (images/visualizations)
861+
feature_type = FeatureType.from_data(sample_value)
862+
target_codec = codec_config.get_codec_for_feature(feature_type, feature_name)
863+
container_codec = codec_config.get_container_codec(target_codec)
864+
if container_codec in {"ffv1", "libaom-av1", "libx264", "libx265"}:
865+
return (1, feature_name)
866+
867+
# Third priority: everything else
868+
return (2, feature_name)
869+
870+
# Sort features by priority
871+
sorted_features = sorted(flattened_data.items(), key=get_feature_priority)
872+
849873
feature_to_stream_idx = {}
850874

851-
for feature_name, sample_value in flattened_data.items():
875+
for feature_name, sample_value in sorted_features:
852876
# Determine feature type from sample
853877
feature_type = FeatureType.from_data(sample_value)
854878

@@ -866,7 +890,7 @@ def create_streams_for_batch_data(
866890

867891
feature_to_stream_idx[feature_name] = stream.index
868892

869-
logger.debug(f"Created stream for '{feature_name}' with codec '{container_codec}' (target: '{target_codec}')")
893+
logger.debug(f"Created stream for '{feature_name}' with codec '{container_codec}' (target: '{target_codec}') at index {stream.index}")
870894

871895
return feature_to_stream_idx
872896

@@ -876,7 +900,7 @@ def encode_batch_data_directly(
876900
feature_to_stream_idx: Dict[str, int],
877901
codec_config: Any,
878902
feature_name_separator: str = "/",
879-
fps: int = 10
903+
fps: Union[int, Dict[str, int]] = 10
880904
) -> None:
881905
"""Encode a batch of data directly to target codecs without intermediate transcoding.
882906
@@ -885,12 +909,32 @@ def encode_batch_data_directly(
885909
feature_to_stream_idx: Mapping of feature names to stream indices
886910
codec_config: Codec configuration
887911
feature_name_separator: Separator for nested feature names
888-
fps: Frames per second for timestamp calculation
912+
fps: Frames per second for timestamp calculation. Can be an int (same fps for all features) or Dict[str, int] (per-feature fps)
889913
"""
890914
from robodm.utils.flatten import _flatten_dict
891915

892-
time_interval_ms = 1000 / fps
893-
current_timestamp = 0
916+
# Handle fps parameter - can be int or dict
917+
if isinstance(fps, int):
918+
# Use same fps for all features
919+
default_fps = fps
920+
feature_fps = {}
921+
else:
922+
# Per-feature fps specified
923+
feature_fps = fps
924+
default_fps = 10 # Fallback default
925+
926+
# Initialize per-feature timestamps and time intervals
927+
feature_timestamps = {}
928+
feature_time_intervals = {}
929+
930+
# Get all feature names from first sample to initialize timestamps
931+
if data_batch:
932+
first_sample = _flatten_dict(data_batch[0], sep=feature_name_separator)
933+
for feature_name in first_sample.keys():
934+
if feature_name in feature_to_stream_idx:
935+
fps_for_feature = feature_fps.get(feature_name, default_fps)
936+
feature_timestamps[feature_name] = 0
937+
feature_time_intervals[feature_name] = 1000.0 / fps_for_feature
894938

895939
for step_data in data_batch:
896940
flattened_data = _flatten_dict(step_data, sep=feature_name_separator)
@@ -899,6 +943,9 @@ def encode_batch_data_directly(
899943
if feature_name in feature_to_stream_idx:
900944
stream_idx = feature_to_stream_idx[feature_name]
901945

946+
# Get current timestamp for this feature
947+
current_timestamp = feature_timestamps.get(feature_name, 0)
948+
902949
# Encode directly to target format
903950
packet_infos = self.encode_data_to_packets(
904951
data=value,
@@ -911,5 +958,7 @@ def encode_batch_data_directly(
911958
# Mux packets immediately
912959
for packet_info in packet_infos:
913960
self.mux_packet_info(packet_info)
914-
915-
current_timestamp += time_interval_ms
961+
962+
# Update timestamp for this feature
963+
time_interval = feature_time_intervals.get(feature_name, 1000.0 / default_fps)
964+
feature_timestamps[feature_name] = current_timestamp + time_interval

robodm/trajectory.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -783,7 +783,7 @@ def from_list_of_dicts(
783783
video_codec: str = "auto",
784784
codec_options: Optional[Dict[str, Any]] = None,
785785
visualization_feature: Optional[Text] = None,
786-
fps: Optional[int] = 10,
786+
fps: Optional[Union[int, Dict[str, int]]] = 10,
787787
raw_codec: Optional[str] = None,
788788
) -> "Trajectory":
789789
"""
@@ -795,6 +795,7 @@ def from_list_of_dicts(
795795
video_codec (str, optional): Video codec to use for video/image features. Defaults to "auto".
796796
codec_options (Dict[str, Any], optional): Additional codec-specific options.
797797
visualization_feature: Optional feature name to prioritize as first stream for visualization.
798+
fps: Optional fps for features. Can be an int (same fps for all features) or Dict[str, int] (per-feature fps).
798799
raw_codec (str, optional): Raw codec to use for non-image features. Defaults to None.
799800
800801
Example:
@@ -822,7 +823,8 @@ def from_list_of_dicts(
822823
feature_to_stream_idx = traj.backend.create_streams_for_batch_data(
823824
sample_data=sample_data,
824825
codec_config=traj.codec_config,
825-
feature_name_separator=traj.feature_name_separator
826+
feature_name_separator=traj.feature_name_separator,
827+
visualization_feature=visualization_feature
826828
)
827829

828830
# Update feature type tracking for consistency
@@ -854,7 +856,7 @@ def from_dict_of_lists(
854856
video_codec: str = "auto",
855857
codec_options: Optional[Dict[str, Any]] = None,
856858
visualization_feature: Optional[Text] = None,
857-
fps: Optional[int] = 10,
859+
fps: Optional[Union[int, Dict[str, int]]] = 10,
858860
raw_codec: Optional[str] = None,
859861
) -> "Trajectory":
860862
"""
@@ -867,6 +869,7 @@ def from_dict_of_lists(
867869
video_codec (str, optional): Video codec to use for video/image features. Defaults to "auto".
868870
codec_options (Dict[str, Any], optional): Additional codec-specific options.
869871
visualization_feature: Optional feature name to prioritize as first stream for visualization.
872+
fps: Optional fps for features. Can be an int (same fps for all features) or Dict[str, int] (per-feature fps).
870873
raw_codec (str, optional): Raw codec to use for non-image features. Defaults to None.
871874
872875
Returns:

robodm/trajectory_base.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,7 @@ def from_list_of_dicts(
9797
video_codec: str = "auto",
9898
codec_options: Optional[Dict[str, Any]] = None,
9999
visualization_feature: Optional[Text] = None,
100-
fps: Optional[int] = 10,
100+
fps: Optional[Union[int, Dict[str, int]]] = 10,
101101
raw_codec: Optional[str] = None,
102102
) -> "TrajectoryInterface":
103103
"""
@@ -124,7 +124,7 @@ def from_dict_of_lists(
124124
video_codec: str = "auto",
125125
codec_options: Optional[Dict[str, Any]] = None,
126126
visualization_feature: Optional[Text] = None,
127-
fps: Optional[int] = 10,
127+
fps: Optional[Union[int, Dict[str, int]]] = 10,
128128
raw_codec: Optional[str] = None,
129129
) -> "TrajectoryInterface":
130130
"""

0 commit comments

Comments
 (0)