|
3 | 3 | import os |
4 | 4 | from typing import Callable, List, Literal, Optional |
5 | 5 |
|
6 | | -import cv2 |
7 | 6 | import numpy as np |
8 | 7 | import torch |
9 | 8 | from hydra_zen import MISSING, store |
10 | | -from PIL import Image |
| 9 | +from lightning_utilities.core.imports import RequirementCache |
11 | 10 | from PIL.Image import Image as PILImage |
12 | 11 | from torch.utils.data import Dataset |
13 | 12 | from torchvision.transforms.v2.functional import to_pil_image |
|
17 | 16 | from mmlearn.datasets.core.example import Example |
18 | 17 |
|
19 | 18 |
|
| 19 | +_OPENCV_AVAILABLE = RequirementCache("opencv-python>=4.10.0.84") |
| 20 | +if _OPENCV_AVAILABLE: |
| 21 | + import cv2 # noqa: F401 |
| 22 | + |
20 | 23 | _LABELS = [ |
21 | 24 | "bathroom", |
22 | 25 | "bedroom", |
@@ -97,7 +100,7 @@ def convert_depth_to_disparity( |
97 | 100 | lines = fh.readlines() |
98 | 101 | focal_length = float(lines[0].strip().split()[0]) |
99 | 102 | baseline = sensor_to_params[sensor_type]["baseline"] |
100 | | - depth_image = np.array(Image.open(depth_file)) |
| 103 | + depth_image = np.array(PILImage.open(depth_file)) |
101 | 104 | depth = np.array(depth_image).astype(np.float32) |
102 | 105 | depth_in_meters = depth / 1000.0 |
103 | 106 | if min_depth is not None: |
@@ -143,6 +146,11 @@ def __init__( |
143 | 146 | depth_transform: Optional[Callable[[PILImage], torch.Tensor]] = None, |
144 | 147 | ) -> None: |
145 | 148 | super().__init__() |
| 149 | + if not _OPENCV_AVAILABLE: |
| 150 | + raise ImportError( |
| 151 | + "SUN RGB-D dataset requires `opencv-python` which is not installed.", |
| 152 | + ) |
| 153 | + |
146 | 154 | self._validate_args(root_dir, split, rgb_transform, depth_transform) |
147 | 155 | self.return_type = return_type |
148 | 156 |
|
|
0 commit comments