Skip to content

Commit 2674b44

Browse files
authored
fix mutli audio sources (#3343)
* fix mutli audio sources * typing
1 parent a795373 commit 2674b44

1 file changed

Lines changed: 6 additions & 5 deletions

File tree

libs/libcommon/src/libcommon/viewer_utils/features.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
SUPPORTED_AUDIO_EXTENSIONS,
4141
AudioSource,
4242
ImageSource,
43+
PDFSource,
4344
VideoSource,
4445
create_audio_file,
4546
create_image_file,
@@ -97,7 +98,7 @@ def image(
9798
hf_endpoint: str,
9899
hf_token: Optional[str],
99100
json_path: Optional[list[Union[str, int]]] = None,
100-
) -> Any:
101+
) -> Optional[ImageSource]:
101102
if value is None:
102103
return None
103104
if isinstance(value, dict) and value.get("bytes"):
@@ -152,7 +153,7 @@ def audio(
152153
storage_client: StorageClient,
153154
hf_endpoint: str,
154155
json_path: Optional[list[Union[str, int]]] = None,
155-
) -> Any:
156+
) -> Optional[list[AudioSource]]:
156157
from datasets.features._torchcodec import AudioDecoder
157158

158159
if value is None:
@@ -177,7 +178,7 @@ def audio(
177178
if audio_file_extension in SUPPORTED_AUDIO_EXTENSION_TO_MEDIA_TYPE:
178179
if value["path"].startswith(f"hf://datasets/{dataset}@"):
179180
src = value["path"].replace("hf://", hf_endpoint + "/", 1).replace("@", "/resolve/", 1)
180-
return AudioSource(src=src, type=SUPPORTED_AUDIO_EXTENSION_TO_MEDIA_TYPE[audio_file_extension])
181+
return [AudioSource(src=src, type=SUPPORTED_AUDIO_EXTENSION_TO_MEDIA_TYPE[audio_file_extension])]
181182

182183
audio_file_bytes = get_audio_file_bytes(value)
183184
if not audio_file_extension:
@@ -281,7 +282,7 @@ def video(
281282
storage_client: StorageClient,
282283
hf_endpoint: str,
283284
json_path: Optional[list[Union[str, int]]] = None,
284-
) -> Any:
285+
) -> Optional[VideoSource]:
285286
if datasets.config.TORCHCODEC_AVAILABLE:
286287
from torchcodec.decoders import VideoDecoder
287288

@@ -372,7 +373,7 @@ def pdf(
372373
hf_endpoint: str,
373374
hf_token: Optional[str],
374375
json_path: Optional[list[Union[str, int]]] = None,
375-
) -> Any:
376+
) -> Optional[PDFSource]:
376377
if value is None:
377378
return None
378379
if isinstance(value, dict) and value.get("bytes"):

0 commit comments

Comments
 (0)