diff --git a/README.md b/README.md index 9768d2d0..3875893a 100644 --- a/README.md +++ b/README.md @@ -169,11 +169,13 @@ A more accessible, comprehensive, and efficient toolkit for large model compress diff --git a/README_cn.md b/README_cn.md index 6085e27d..8b9d262e 100644 --- a/README_cn.md +++ b/README_cn.md @@ -170,11 +170,13 @@ diff --git a/angelslim/data/audio_dataset.py b/angelslim/data/audio_dataset.py new file mode 100644 index 00000000..0a66e468 --- /dev/null +++ b/angelslim/data/audio_dataset.py @@ -0,0 +1,139 @@ +# Copyright 2025 Tencent Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from typing import Dict, List, Union + +import requests +from transformers import ProcessorMixin +from transformers.pipelines.audio_utils import ffmpeg_read + +from .base_dataset import BaseDataset + + +class AudioDataset(BaseDataset): + """Dataset for multimodal (text + image) data""" + + def __init__( + self, + processor: ProcessorMixin, + device: str = "cpu", + max_length: int = 4096, + num_samples: int = -1, + data_source: Union[str, Dict] = None, + is_hf_dataset: bool = False, + model_name: str = None, + ): + super().__init__(processor, device, max_length) + self.is_hf_dataset = is_hf_dataset + self.model_name = model_name + + self._load_file_based_dataset(data_source, num_samples) + + def _load_file_based_dataset(self, data_path: str, num_samples: int): + """Load dataset from local file system""" + audio_dir = os.path.join(os.path.dirname(data_path), "audios") + line_count = 0 + + with open(data_path, "r") as f: + for line in f: + if num_samples > 0 and line_count >= num_samples: + break + + data = json.loads(line.strip()) + if data["audio_path"].startswith("http://") or data[ + "audio_path" + ].startswith("https://"): + audio_path = data["audio_path"] + else: + audio_path = os.path.join(audio_dir, data["audio_path"]) + + # Prepare chat messages with image + messages = [ + { + "role": "user", + "content": [ + {"type": "audio", "audio_url": audio_path}, + { + "type": "text", + "text": data["question"].replace("