Skip to content

Commit 039b5d0

Browse files
authored
Fix usage of optional dependencies (#12)
1 parent 362af2a commit 039b5d0

6 files changed

Lines changed: 437 additions & 407 deletions

File tree

mmlearn/conf/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Hydra/Hydra-zen-based configurations."""
22

33
import functools
4+
import warnings
45
from dataclasses import dataclass, field
56
from enum import Enum
67
from pathlib import Path
@@ -446,6 +447,11 @@ class name is used.
446447
group="trainer/logger",
447448
provider="lightning",
448449
)
450+
else:
451+
warnings.warn(
452+
"wandb is not available. Skipping registration of 'trainer/logger/WandbLogger'.",
453+
stacklevel=1,
454+
)
449455

450456

451457
#################### Custom Hydra Main Decorator ####################

mmlearn/datasets/librispeech.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
import torch
66
import torch.nn.functional as F # noqa: N812
77
from hydra_zen import MISSING, store
8+
from lightning_utilities.core.imports import RequirementCache
89
from torch.utils.data.dataset import Dataset
9-
from torchaudio.datasets import LIBRISPEECH
1010

1111
from mmlearn.constants import EXAMPLE_INDEX_KEY
1212
from mmlearn.datasets.core import Modalities
1313
from mmlearn.datasets.core.example import Example
1414

1515

16+
_TORCHAUDIO_AVAILABLE = RequirementCache("torchaudio>=2.4.0")
1617
SAMPLE_RATE = 16000
1718

1819

@@ -80,6 +81,12 @@ class LibriSpeech(Dataset[Example]):
8081
def __init__(self, root_dir: str, split: str = "train-clean-100") -> None:
8182
"""Initialize LibriSpeech dataset."""
8283
super().__init__()
84+
if not _TORCHAUDIO_AVAILABLE:
85+
raise ImportError(
86+
"LibriSpeech dataset requires `torchaudio` which is not installed."
87+
)
88+
from torchaudio.datasets import LIBRISPEECH
89+
8390
self.dataset = LIBRISPEECH(
8491
root=root_dir,
8592
url=split,

mmlearn/datasets/nyuv2.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
import os
44
from typing import Callable, List, Literal, Optional
55

6-
import cv2
76
import numpy as np
87
import torch
98
from hydra_zen import MISSING, store
10-
from PIL import Image
9+
from lightning_utilities.core.imports import RequirementCache
1110
from PIL.Image import Image as PILImage
1211
from torch.utils.data import Dataset
1312
from torchvision.transforms.v2.functional import to_pil_image
@@ -17,6 +16,11 @@
1716
from mmlearn.datasets.core.example import Example
1817

1918

19+
_OPENCV_AVAILABLE = RequirementCache("opencv-python>=4.10.0.84")
20+
if _OPENCV_AVAILABLE:
21+
import cv2 # noqa: F401
22+
23+
2024
_LABELS = [
2125
"bedroom",
2226
"kitchen",
@@ -56,7 +60,7 @@ def depth_normalize(
5660
torch.Tensor
5761
Normalized depth image.
5862
"""
59-
depth_image = np.array(Image.open(depth_file))
63+
depth_image = np.array(PILImage.open(depth_file))
6064
depth = np.array(depth_image).astype(np.float32)
6165
depth_in_meters = depth / 1000.0
6266

@@ -100,6 +104,10 @@ def __init__(
100104
depth_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
101105
) -> None:
102106
super().__init__()
107+
if not _OPENCV_AVAILABLE:
108+
raise ImportError(
109+
"NYUv2 dataset requires `opencv-python` which is not installed.",
110+
)
103111
self._validate_args(root_dir, split, rgb_transform, depth_transform)
104112
self.return_type = return_type
105113

mmlearn/datasets/sunrgbd.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
import os
44
from typing import Callable, List, Literal, Optional
55

6-
import cv2
76
import numpy as np
87
import torch
98
from hydra_zen import MISSING, store
10-
from PIL import Image
9+
from lightning_utilities.core.imports import RequirementCache
1110
from PIL.Image import Image as PILImage
1211
from torch.utils.data import Dataset
1312
from torchvision.transforms.v2.functional import to_pil_image
@@ -17,6 +16,10 @@
1716
from mmlearn.datasets.core.example import Example
1817

1918

19+
_OPENCV_AVAILABLE = RequirementCache("opencv-python>=4.10.0.84")
20+
if _OPENCV_AVAILABLE:
21+
import cv2 # noqa: F401
22+
2023
_LABELS = [
2124
"bathroom",
2225
"bedroom",
@@ -97,7 +100,7 @@ def convert_depth_to_disparity(
97100
lines = fh.readlines()
98101
focal_length = float(lines[0].strip().split()[0])
99102
baseline = sensor_to_params[sensor_type]["baseline"]
100-
depth_image = np.array(Image.open(depth_file))
103+
depth_image = np.array(PILImage.open(depth_file))
101104
depth = np.array(depth_image).astype(np.float32)
102105
depth_in_meters = depth / 1000.0
103106
if min_depth is not None:
@@ -143,6 +146,11 @@ def __init__(
143146
depth_transform: Optional[Callable[[PILImage], torch.Tensor]] = None,
144147
) -> None:
145148
super().__init__()
149+
if not _OPENCV_AVAILABLE:
150+
raise ImportError(
151+
"SUN RGB-D dataset requires `opencv-python` which is not installed.",
152+
)
153+
146154
self._validate_args(root_dir, split, rgb_transform, depth_transform)
147155
self.return_type = return_type
148156

0 commit comments

Comments
 (0)