Skip to content

Commit c60f75c

Browse files
b-enedict, christinadionysio
authored and committed
[SYSTEMDS-3944] Modality Alignment, Contrastive Learning, new Data Loaders
Summary

This PR introduces new functionality for multimodal learning in Scuro, including a contrastive learning operator, a modality alignment operator, and additional data loaders.

Changes

Contrastive Learning Operator
- Constructs modality pairs via a Cartesian product
- Uses a user-defined function to label pairs as positive or negative
- Enables dynamic generation of contrastive samples

Modality Alignment Operator
- Aligns previously unaligned modalities using feature-based similarity (e.g., ORB, perceptual hashing)
- Outputs a matching between a primary and secondary modality
- Matching is applied after representation learning and before fusion

Data Loaders
- PDF loader: converts document pages into NumPy arrays for OpenCV processing
- Audio loader: converts audio to text using faster-whisper

Closes #2461
1 parent c2d0e5a commit c60f75c

32 files changed

Lines changed: 557 additions & 124 deletions

src/main/python/systemds/scuro/dataloader/audio_loader.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,18 +63,17 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
6363
if not self.load_data_from_file:
6464
import numpy as np
6565

66-
self.metadata[file] = self.modality_type.create_metadata(
67-
1000, np.array([0])
68-
)
66+
audio = np.array([0])
67+
sr = 1000
6968
else:
7069
audio, sr = librosa.load(file, dtype=self._data_type)
7170

7271
if self.normalize:
7372
audio = librosa.util.normalize(audio)
7473

75-
self.metadata[file] = self.modality_type.create_metadata(sr, audio)
74+
self.metadata.append(self.modality_type.create_metadata(sr, audio))
7675

77-
self.data.append(audio)
76+
self.data.append(audio)
7877

7978
def get_stats(self, source_path: str):
8079
sampling_rate = 0

src/main/python/systemds/scuro/dataloader/base_loader.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@ def __init__(
4444
(otherwise please provide your own Dataloader that knows about the file name convention)
4545
"""
4646
self.data = []
47-
self.metadata = (
48-
{}
49-
) # TODO: check what the index should be for storing the metadata (file_name, counter, ...)
47+
self.metadata = []
5048
self.source_path = source_path
5149
self.indices = indices
5250
self.modality_type = modality_type
@@ -87,7 +85,7 @@ def data_type(self, data_type):
8785
def reset(self):
8886
self._next_chunk = 0
8987
self.data = []
90-
self.metadata = {}
88+
self.metadata = []
9189

9290
def load(self):
9391
"""
@@ -134,6 +132,7 @@ def _load_next_chunk(self):
134132
Loads the next chunk of data
135133
"""
136134
self.data = []
135+
# TODO: Handle metadata correctly
137136
next_chunk_indices = self.indices[
138137
self._next_chunk
139138
* self._chunk_size : (self._next_chunk + 1)

src/main/python/systemds/scuro/dataloader/image_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
7171

7272
image = image.astype(np.uint8, copy=False)
7373

74-
self.metadata[file] = self.modality_type.create_metadata(
75-
width, height, channels
74+
self.metadata.append(
75+
self.modality_type.create_metadata(width, height, channels)
7676
)
7777

7878
self.data.append(image)

src/main/python/systemds/scuro/dataloader/json_loader.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
6969

7070
text = " ".join(text) if isinstance(text, list) else text
7171
self.data.append(text)
72-
self.metadata[idx] = self.modality_type.create_metadata(len(text), text)
72+
self.metadata.append(
73+
self.modality_type.create_metadata(len(text), text) | json_file[idx]
74+
)
7375

7476
def get_stats(self, source_path: str):
7577
self.file_sanity_check(source_path)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# -------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
# -------------------------------------------------------------
21+
from typing import List, Optional, Union
22+
import pymupdf
23+
24+
import numpy as np
25+
26+
from systemds.scuro.dataloader.base_loader import BaseLoader
27+
import cv2
28+
from systemds.scuro.modality.type import ModalityType
29+
30+
31+
class PdfLoader(BaseLoader):
    """Data loader that turns each page of a PDF into an RGB image array.

    Every page is rendered to a pixmap, encoded as JPEG, decoded with OpenCV
    and converted from BGR to RGB. One ``uint8`` array and one image metadata
    entry are appended per page, so a multi-page document contributes several
    entries to ``self.data`` / ``self.metadata``.
    """

    def __init__(
        self,
        source_path: str,
        indices: List[str],
        data_type: Union[np.dtype, str] = np.float16,
        chunk_size: Optional[int] = None,
        load=True,
        ext=".pdf",
    ):
        """Create a PDF loader.

        :param source_path: Location of the PDF file(s), forwarded to BaseLoader.
        :param indices: Instance identifiers, forwarded to BaseLoader.
        :param data_type: Data type forwarded to BaseLoader; the page images
            produced by ``extract`` are always ``uint8``.
        :param chunk_size: Optional chunk size, forwarded to BaseLoader.
        :param load: Stored as ``load_data_from_file``.
        :param ext: File extension, forwarded to BaseLoader.
        """
        super().__init__(
            source_path, indices, data_type, chunk_size, ModalityType.IMAGE, ext
        )
        # NOTE(review): this flag is stored but extract() never consults it --
        # confirm whether a "metadata only" mode is intended for PDFs.
        self.load_data_from_file = load

    def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
        """Render all pages of ``file`` and append them to data and metadata.

        :param file: Path of the PDF document to read.
        :param index: Unused by this loader.
        """
        self.file_sanity_check(file)

        # The context manager closes the document even if rendering fails;
        # previously the handle was never released.
        with pymupdf.open(file) as doc:
            for page in doc.pages():
                image_bytes = page.get_pixmap().tobytes("jpg")
                np_buffer = np.frombuffer(image_bytes, dtype=np.uint8)

                # IMREAD_COLOR always decodes to three channels (BGR), so the
                # result is three-dimensional and needs no grayscale branch.
                image = cv2.imdecode(np_buffer, cv2.IMREAD_COLOR)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

                height, width, channels = image.shape
                image = image.astype(np.uint8, copy=False)

                self.metadata.append(
                    self.modality_type.create_metadata(width, height, channels)
                )
                self.data.append(image)

src/main/python/systemds/scuro/dataloader/text_loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
5656
if self.prefix:
5757
line = re.sub(self.prefix, "", line)
5858
line = line.replace("\n", "")
59-
self.metadata[file] = self.modality_type.create_metadata(
60-
len(line.split()), line
59+
self.metadata.append(
60+
self.modality_type.create_metadata(len(line.split()), line)
6161
)
6262
self.data.append(line)
6363

src/main/python/systemds/scuro/dataloader/timeseries_loader.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,20 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
8181
data = self._normalize_signals(data)
8282

8383
if file:
84-
self.metadata[index] = self.modality_type.create_metadata(
85-
self.signal_names, data, self.sampling_rate
84+
self.metadata.append(
85+
self.modality_type.create_metadata(
86+
self.signal_names, data, self.sampling_rate
87+
)
8688
)
89+
self.data.append(data)
8790
else:
8891
for i, index in enumerate(self.indices):
89-
self.metadata[str(index)] = self.modality_type.create_metadata(
90-
self.signal_names, data[i], self.sampling_rate
92+
self.metadata.append(
93+
self.modality_type.create_metadata(
94+
self.signal_names, data[i], self.sampling_rate
95+
)
9196
)
92-
self.data.append(data)
97+
self.data.append(data[i])
9398

9499
def _normalize_signals(self, data: np.ndarray) -> np.ndarray:
95100
if data.ndim == 1:
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# -------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
# -------------------------------------------------------------
21+
from typing import List, Optional, Union
22+
from faster_whisper import WhisperModel
23+
import numpy as np
24+
25+
from systemds.scuro.dataloader.base_loader import BaseLoader
26+
from systemds.scuro.modality.type import ModalityType
27+
28+
29+
class TranscriptLoader(BaseLoader):
    """Data loader that transcribes audio files into text segments.

    A faster-whisper model transcribes each file; every segment it yields
    becomes one text entry in ``self.data`` together with a metadata entry
    carrying the segment's start/end timestamps and its text.
    """

    def __init__(
        self,
        source_path: str,
        indices: List[str],
        data_type: Union[np.dtype, str] = np.float32,
        chunk_size: Optional[int] = None,
        normalize: bool = True,
        transcribe_model_size: str = "medium",
        load=True,
    ):
        """Create a transcript loader.

        :param source_path: Location of the audio file(s), forwarded to BaseLoader.
        :param indices: Instance identifiers, forwarded to BaseLoader.
        :param data_type: Data type forwarded to BaseLoader.
        :param chunk_size: Optional chunk size, forwarded to BaseLoader.
        :param normalize: Stored as ``self.normalize``.
        :param transcribe_model_size: Model size passed to ``WhisperModel``.
        :param load: Stored as ``load_data_from_file``.
        """
        super().__init__(source_path, indices, data_type, chunk_size, ModalityType.TEXT)
        # The model is configured for CPU inference with int8 computation.
        self.model = WhisperModel(
            transcribe_model_size, device="cpu", compute_type="int8"
        )
        # NOTE(review): neither flag below is read by extract() -- confirm
        # whether they are still needed for this loader.
        self.normalize = normalize
        self.load_data_from_file = load

    def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
        """Transcribe ``file`` and append one entry per transcribed segment.

        :param file: Path of the audio file to transcribe.
        :param index: Unused by this loader.
        """
        self.file_sanity_check(file)
        # transcribe() returns a pair; only the segments are used here.
        segments, _info = self.model.transcribe(file, vad_filter=True)

        for segment in segments:
            word_count = len(segment.text.split())
            segment_meta = self.modality_type.create_metadata(word_count, segment.text)
            segment_meta["timestamp_start"] = segment.start
            segment_meta["timestamp_end"] = segment.end
            segment_meta["text"] = segment.text

            self.metadata.append(segment_meta)
            self.data.append(segment.text)

src/main/python/systemds/scuro/dataloader/video_loader.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,10 @@ def extract(self, file: str, index: Optional[Union[str, List[str]]] = None):
8787
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
8888
num_channels = 3
8989

90-
self.metadata[file] = self.modality_type.create_metadata(
91-
self.fps, length, width, height, num_channels
90+
self.metadata.append(
91+
self.modality_type.create_metadata(
92+
self.fps, length, width, height, num_channels
93+
)
9294
)
9395

9496
frames = []

src/main/python/systemds/scuro/modality/joined.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,8 @@ def execute(self, starting_idx=0):
7777
)
7878

7979
for i in range(start, end):
80-
idx_1 = list(self.left_modality.metadata.values())[i + starting_idx][
81-
self.condition.leftField
82-
]
80+
left_meta_idx = i if self.chunk_left else i + starting_idx
81+
idx_1 = self.left_modality.metadata[left_meta_idx][self.condition.leftField]
8382
if (
8483
self.condition.alignment is None and self.condition.join_type == "<"
8584
): # TODO compute correct alignment timestamps/spatial params
@@ -90,9 +89,7 @@ def execute(self, starting_idx=0):
9089
if self.chunk_left:
9190
i = i + starting_idx
9291

93-
idx_2 = list(self.right_modality.metadata.values())[i][
94-
self.condition.rightField
95-
]
92+
idx_2 = self.right_modality.metadata[i][self.condition.rightField]
9693
self.joined_right.data.append([])
9794

9895
c = 0
@@ -228,8 +225,8 @@ def _handle_chunked_execution(self, representation):
228225
def _apply_representation_chunked(
229226
self, left_modality, right_modality, chunk_right, representation
230227
):
231-
new_left = Modality(left_modality.modality_type, {})
232-
new_right = Modality(right_modality.modality_type, {})
228+
new_left = Modality(left_modality.modality_type)
229+
new_right = Modality(right_modality.modality_type)
233230

234231
for _ in left_modality.iter_raw_data_chunks(reset=True):
235232
if chunk_right:
@@ -246,11 +243,11 @@ def _apply_representation_chunked(
246243
self.joined_right, representation
247244
)
248245
new_right.data.extend(right_transformed.data)
249-
new_right.metadata.update(right_transformed.metadata)
246+
new_right.metadata.extend(right_transformed.metadata)
250247

251248
left_transformed = self._apply_representation(left_modality, representation)
252249
new_left.data.extend(left_transformed.data)
253-
new_left.metadata.update(left_transformed.metadata)
250+
new_left.metadata.extend(left_transformed.metadata)
254251

255252
new_left.update_metadata()
256253
new_right.update_metadata()

0 commit comments

Comments
 (0)