44from typing import Optional , Union , Any , Dict , List
55from nucleus .constants import (
66 FRAMES_KEY ,
7+ LENGTH_KEY ,
78 METADATA_KEY ,
9+ NUM_SENSORS_KEY ,
810 REFERENCE_ID_KEY ,
911 POINTCLOUD_LOCATION_KEY ,
1012 IMAGE_LOCATION_KEY ,
@@ -25,9 +27,28 @@ def __post_init__(self):
2527 value , DatasetItem
2628 ), "All values must be DatasetItems"
2729
30+ def __repr__ (self ) -> str :
31+ return f"Frame(index={ self .index } , items={ self .items } )"
32+
2833 def add_item (self , item : DatasetItem , sensor_name : str ):
2934 self .items [sensor_name ] = item
3035
36+ def get_item (self , sensor_name : str ):
37+ if sensor_name not in self .items :
38+ raise ValueError (
39+ f"This frame does not have a { sensor_name } sensor"
40+ )
41+ return self .items [sensor_name ]
42+
43+ def get_items (self ):
44+ return list (self .items .values ())
45+
46+ def get_sensors (self ):
47+ return list (self .items .keys ())
48+
49+ def get_index (self ):
50+ return self .index
51+
3152 @classmethod
3253 def from_json (cls , payload : dict ):
3354 items = {
@@ -51,6 +72,9 @@ class Scene(ABC):
5172
5273 def __post_init__ (self ):
5374 self .check_valid_frame_indices ()
75+ self .sensors = set (
76+ flatten ([frame .get_sensors () for frame in self .frames ])
77+ )
5478 if all ((frame .index is not None for frame in self .frames )):
5579 self .frames_dict = {frame .index : frame for frame in self .frames }
5680 else :
@@ -60,6 +84,14 @@ def __post_init__(self):
6084 ]
6185 self .frames_dict = dict (enumerate (indexed_frames ))
6286
87+ @property
88+ def length (self ) -> int :
89+ return len (self .frames_dict )
90+
91+ @property
92+ def num_sensors (self ) -> int :
93+ return len (self .get_sensors ())
94+
6395 def check_valid_frame_indices (self ):
6496 infer_from_list_position = all (
6597 (frame .index is None for frame in self .frames )
@@ -72,15 +104,14 @@ def check_valid_frame_indices(self):
72104 ), "Must specify index explicitly for all frames or infer from list position for all frames"
73105
74106 def validate (self ):
75- assert (
76- len (self .frames_dict ) > 0
77- ), "Must have at least 1 frame in a scene"
107+ assert self .length > 0 , "Must have at least 1 frame in a scene"
78108 for frame in self .frames_dict .values ():
79109 assert isinstance (
80110 frame , Frame
81111 ), "Each frame in a scene must be a Frame object"
82112
83113 def add_item (self , index : int , sensor_name : str , item : DatasetItem ):
114+ self .sensors .add (sensor_name )
84115 if index not in self .frames_dict :
85116 new_frame = Frame (index = index , items = {sensor_name : item })
86117 self .frames_dict [index ] = new_frame
@@ -97,6 +128,54 @@ def add_frame(self, frame: Frame, update: bool = False):
97128 and update
98129 ):
99130 self .frames_dict [frame .index ] = frame
131+ self .sensors .update (frame .get_sensors ())
132+
133+ def get_frame (self , index : int ):
134+ if index not in self .frames_dict :
135+ raise ValueError (
136+ f"This scene does not have a frame at index { index } "
137+ )
138+ return self .frames_dict [index ]
139+
140+ def get_frames (self ):
141+ return [
142+ frame
143+ for _ , frame in sorted (
144+ self .frames_dict .items (), key = lambda x : x [0 ]
145+ )
146+ ]
147+
148+ def get_sensors (self ):
149+ return list (self .sensors )
150+
151+ def get_item (self , index : int , sensor_name : str ):
152+ frame = self .get_frame (index )
153+ return frame .get_item (sensor_name )
154+
155+ def get_items_from_sensor (self , sensor_name : str ):
156+ if sensor_name not in self .sensors :
157+ raise ValueError (
158+ f"This scene does not have a { sensor_name } sensor"
159+ )
160+ items_from_sensor = []
161+ for frame in self .frames_dict .values ():
162+ try :
163+ sensor_item = frame .get_item (sensor_name )
164+ items_from_sensor .append (sensor_item )
165+ except ValueError :
166+ # This sensor is not present at current frame
167+ items_from_sensor .append (None )
168+ return items_from_sensor
169+
170+ def get_items (self ):
171+ return flatten ([frame .get_items () for frame in self .get_frames ()])
172+
173+ def info (self ):
174+ return {
175+ REFERENCE_ID_KEY : self .reference_id ,
176+ LENGTH_KEY : self .length ,
177+ NUM_SENSORS_KEY : self .num_sensors ,
178+ }
100179
101180 def validate_frames_dict (self ):
102181 is_continuous = set (list (range (len (self .frames_dict )))) == set (
@@ -118,12 +197,7 @@ def from_json(cls, payload: dict):
118197
119198 def to_payload (self ) -> dict :
120199 self .validate_frames_dict ()
121- ordered_frames = [
122- frame
123- for _ , frame in sorted (
124- self .frames_dict .items (), key = lambda x : x [0 ]
125- )
126- ]
200+ ordered_frames = self .get_frames ()
127201 frames_payload = [frame .to_payload () for frame in ordered_frames ]
128202 payload : Dict [str , Any ] = {
129203 REFERENCE_ID_KEY : self .reference_id ,
@@ -139,27 +213,30 @@ def to_json(self) -> str:
139213
140214@dataclass
141215class LidarScene (Scene ):
216+ def __repr__ (self ) -> str :
217+ return f"LidarScene(reference_id='{ self .reference_id } ', frames={ self .get_frames ()} , metadata={ self .metadata } )"
218+
142219 def validate (self ):
143220 super ().validate ()
144- lidar_sources = flatten (
221+ lidar_sensors = flatten (
145222 [
146223 [
147- source
148- for source in frame .items .keys ()
149- if frame .items [source ].type == DatasetItemType .POINTCLOUD
224+ sensor
225+ for sensor in frame .items .keys ()
226+ if frame .items [sensor ].type == DatasetItemType .POINTCLOUD
150227 ]
151228 for frame in self .frames_dict .values ()
152229 ]
153230 )
154231 assert (
155- len (set (lidar_sources )) == 1
156- ), "Each lidar scene must have exactly one lidar source "
232+ len (set (lidar_sensors )) == 1
233+ ), "Each lidar scene must have exactly one lidar sensor "
157234
158235 for frame in self .frames_dict .values ():
159236 num_pointclouds = sum (
160237 [
161238 int (item .type == DatasetItemType .POINTCLOUD )
162- for item in frame .items . values ()
239+ for item in frame .get_items ()
163240 ]
164241 )
165242 assert (
@@ -173,17 +250,16 @@ def flatten(t):
173250
174251def check_all_scene_paths_remote (scenes : List [LidarScene ]):
175252 for scene in scenes :
176- for frame in scene .frames_dict .values ():
177- for item in frame .items .values ():
178- pointcloud_location = getattr (item , POINTCLOUD_LOCATION_KEY )
179- if pointcloud_location and is_local_path (pointcloud_location ):
180- raise ValueError (
181- f"All paths for DatasetItems in a Scene must be remote, but { item .pointcloud_location } is either "
182- "local, or a remote URL type that is not supported."
183- )
184- image_location = getattr (item , IMAGE_LOCATION_KEY )
185- if image_location and is_local_path (image_location ):
186- raise ValueError (
187- f"All paths for DatasetItems in a Scene must be remote, but { item .image_location } is either "
188- "local, or a remote URL type that is not supported."
189- )
253+ for item in scene .get_items ():
254+ pointcloud_location = getattr (item , POINTCLOUD_LOCATION_KEY )
255+ if pointcloud_location and is_local_path (pointcloud_location ):
256+ raise ValueError (
257+ f"All paths for DatasetItems in a Scene must be remote, but { item .pointcloud_location } is either "
258+ "local, or a remote URL type that is not supported."
259+ )
260+ image_location = getattr (item , IMAGE_LOCATION_KEY )
261+ if image_location and is_local_path (image_location ):
262+ raise ValueError (
263+ f"All paths for DatasetItems in a Scene must be remote, but { item .image_location } is either "
264+ "local, or a remote URL type that is not supported."
265+ )
0 commit comments