diff --git a/t4_devkit/tier4.py b/t4_devkit/tier4.py index a4787b0..7b5a93c 100644 --- a/t4_devkit/tier4.py +++ b/t4_devkit/tier4.py @@ -41,7 +41,7 @@ Visibility, ) -__all__ = ["DBMetadata", "load_metadata", "Tier4"] +__all__ = ["DBMetadata", "load_metadata", "load_table", "Tier4"] @define @@ -92,6 +92,29 @@ def load_metadata(db_root: str, revision: str | None = None) -> DBMetadata: return DBMetadata(data_root=data_root, dataset_id=dataset_id, version=version) +def load_table(annotation_dir: str, schema: SchemaName) -> list[SchemaTable]: + """Load schema table from a JSON file. + + If the schema is optional and there is no corresponding JSON file in dataset, + returns empty list. + + Args: + annotation_dir (str): Path to the directory of JSON annotation schema files. + schema (SchemaName): An enum member of `SchemaName`. + + Returns: + Loaded table data saved in `.json`. + """ + filepath = osp.join(annotation_dir, schema.filename) + if not osp.exists(filepath) and schema.is_optional(): + return [] + + if not osp.exists(filepath): + raise FileNotFoundError(f"{schema.value} is mandatory.") + + return build_schema(schema, filepath) + + class Tier4: """Database class for T4 dataset to help query and retrieve information from the database.""" @@ -152,27 +175,29 @@ def __init__( print("======\nLoading T4 tables...") # assign tables explicitly - self.attribute: list[Attribute] = self.__load_table__(SchemaName.ATTRIBUTE) - self.calibrated_sensor: list[CalibratedSensor] = self.__load_table__( - SchemaName.CALIBRATED_SENSOR + self.attribute: list[Attribute] = load_table(self.annotation_dir, SchemaName.ATTRIBUTE) + self.calibrated_sensor: list[CalibratedSensor] = load_table( + self.annotation_dir, SchemaName.CALIBRATED_SENSOR ) - self.category: list[Category] = self.__load_table__(SchemaName.CATEGORY) - self.ego_pose: list[EgoPose] = self.__load_table__(SchemaName.EGO_POSE) - self.instance: list[Instance] = self.__load_table__(SchemaName.INSTANCE) - self.keypoint: list[Keypoint] = self.__load_table__(SchemaName.KEYPOINT) - self.log: list[Log] = self.__load_table__(SchemaName.LOG) - self.map: list[Map] = self.__load_table__(SchemaName.MAP) - self.object_ann: list[ObjectAnn] = self.__load_table__(SchemaName.OBJECT_ANN) - self.sample_annotation: list[SampleAnnotation] = self.__load_table__( - SchemaName.SAMPLE_ANNOTATION + self.category: list[Category] = load_table(self.annotation_dir, SchemaName.CATEGORY) + self.ego_pose: list[EgoPose] = load_table(self.annotation_dir, SchemaName.EGO_POSE) + self.instance: list[Instance] = load_table(self.annotation_dir, SchemaName.INSTANCE) + self.keypoint: list[Keypoint] = load_table(self.annotation_dir, SchemaName.KEYPOINT) + self.log: list[Log] = load_table(self.annotation_dir, SchemaName.LOG) + self.map: list[Map] = load_table(self.annotation_dir, SchemaName.MAP) + self.object_ann: list[ObjectAnn] = load_table(self.annotation_dir, SchemaName.OBJECT_ANN) + self.sample_annotation: list[SampleAnnotation] = load_table( + self.annotation_dir, SchemaName.SAMPLE_ANNOTATION ) - self.sample_data: list[SampleData] = self.__load_table__(SchemaName.SAMPLE_DATA) - self.sample: list[Sample] = self.__load_table__(SchemaName.SAMPLE) - self.scene: list[Scene] = self.__load_table__(SchemaName.SCENE) - self.sensor: list[Sensor] = self.__load_table__(SchemaName.SENSOR) - self.surface_ann: list[SurfaceAnn] = self.__load_table__(SchemaName.SURFACE_ANN) - self.vehicle_state: list[VehicleState] = self.__load_table__(SchemaName.VEHICLE_STATE) - self.visibility: list[Visibility] = self.__load_table__(SchemaName.VISIBILITY) + self.sample_data: list[SampleData] = load_table(self.annotation_dir, SchemaName.SAMPLE_DATA) + self.sample: list[Sample] = load_table(self.annotation_dir, SchemaName.SAMPLE) + self.scene: list[Scene] = load_table(self.annotation_dir, SchemaName.SCENE) + self.sensor: list[Sensor] = load_table(self.annotation_dir, SchemaName.SENSOR) + self.surface_ann: list[SurfaceAnn] = load_table(self.annotation_dir, SchemaName.SURFACE_ANN) + self.vehicle_state: list[VehicleState] = load_table( + self.annotation_dir, SchemaName.VEHICLE_STATE + ) + self.visibility: list[Visibility] = load_table(self.annotation_dir, SchemaName.VISIBILITY) # make reverse indexes for common lookups self.__make_reverse_index__(verbose) @@ -217,27 +242,6 @@ def bag_dir(self) -> str: """Return the path to ROS bag directory.""" return osp.join(self.data_root, "input_bag") - def __load_table__(self, schema: SchemaName) -> list[SchemaTable]: - """Load schema table from a json file. - - If the schema is optional and there is no corresponding json file in dataset, - returns empty list. - - Args: - schema (SchemaName): An enum member of `SchemaName`. - - Returns: - Loaded table data saved in `.json`. - """ - filepath = osp.join(self.annotation_dir, schema.filename) - if not osp.exists(filepath) and schema.is_optional(): - return [] - - if not osp.exists(filepath): - raise FileNotFoundError(f"{schema.value} is mandatory.") - - return build_schema(schema, filepath) - def __make_reverse_index__(self, verbose: bool) -> None: """De-normalize database to create reverse indices for common cases.