diff --git a/nemo_curator/stages/audio/__init__.py b/nemo_curator/stages/audio/__init__.py index df541d370e..c2663570cc 100644 --- a/nemo_curator/stages/audio/__init__.py +++ b/nemo_curator/stages/audio/__init__.py @@ -36,9 +36,14 @@ SIGMOSFilterStage, UTMOSFilterStage, ) -from nemo_curator.stages.audio.postprocessing import ( - TimestampMapperStage, +from nemo_curator.stages.audio.io import ( + AudioManifestReader, + AudioToDocumentStage, + CleanupTemporaryAudioStage, + MaterializeTarredAudioStage, + TarredAudioManifestReader, ) +from nemo_curator.stages.audio.postprocessing import TimestampMapperStage from nemo_curator.stages.audio.preprocessing import ( MonoConversionStage, SegmentConcatenationStage, @@ -52,15 +57,20 @@ "ALMDataBuilderStage", "ALMDataOverlapStage", "AudioDataFilterStage", + "AudioManifestReader", + "AudioToDocumentStage", "BandFilterStage", + "CleanupTemporaryAudioStage", "GetAudioDurationStage", "ManifestReader", "ManifestWriterStage", + "MaterializeTarredAudioStage", "MonoConversionStage", "PreserveByValueStage", "SIGMOSFilterStage", "SegmentConcatenationStage", "SpeakerSeparationStage", + "TarredAudioManifestReader", "TimestampMapperStage", "UTMOSFilterStage", "VADSegmentationStage", diff --git a/nemo_curator/stages/audio/io/__init__.py b/nemo_curator/stages/audio/io/__init__.py index e69de29bb2..f34438f35d 100644 --- a/nemo_curator/stages/audio/io/__init__.py +++ b/nemo_curator/stages/audio/io/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_curator.stages.audio.io.convert import AudioToDocumentStage +from nemo_curator.stages.audio.io.manifest import AudioManifestReader, AudioManifestReaderStage +from nemo_curator.stages.audio.io.materialize import CleanupTemporaryAudioStage +from nemo_curator.stages.audio.io.tarred import ( + MaterializeTarredAudioStage, + TarredAudioManifestPartitionStage, + TarredAudioManifestReader, + TarredAudioManifestReaderStage, +) + +__all__ = [ + "AudioManifestReader", + "AudioManifestReaderStage", + "AudioToDocumentStage", + "CleanupTemporaryAudioStage", + "MaterializeTarredAudioStage", + "TarredAudioManifestPartitionStage", + "TarredAudioManifestReader", + "TarredAudioManifestReaderStage", +] diff --git a/nemo_curator/stages/audio/io/manifest.py b/nemo_curator/stages/audio/io/manifest.py new file mode 100644 index 0000000000..35e0edc2cd --- /dev/null +++ b/nemo_curator/stages/audio/io/manifest.py @@ -0,0 +1,123 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import Any, Literal + +from nemo_curator.stages.base import CompositeStage, ProcessingStage +from nemo_curator.stages.file_partitioning import FilePartitioningStage +from nemo_curator.tasks import AudioTask, FileGroupTask, _EmptyTask +from nemo_curator.tasks.audio_task import build_audio_sample_key +from nemo_curator.utils.remote_io import open_text_stream + + +@dataclass +class AudioManifestReaderStage(ProcessingStage[FileGroupTask, AudioTask]): + fields: list[str] | None = None + storage_options: dict[str, Any] | None = None + transport: Literal["auto", "fsspec", "pipe"] = "auto" + audio_filepath_key: str = "audio_filepath" + manifest_path_key: str = "_manifest_path" + source_type_key: str = "_audio_source_type" + source_type_value: str = "manifest" + name: str = "audio_manifest_reader" + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + output_fields = list(self.fields or []) + if self.audio_filepath_key not in output_fields: + output_fields.append(self.audio_filepath_key) + output_fields.extend([self.manifest_path_key, self.source_type_key]) + return ["sample_key"], output_fields + + def process(self, task: FileGroupTask) -> list[AudioTask]: + results: list[AudioTask] = [] + for manifest_index, manifest_path in enumerate(task.data): + with open_text_stream( + manifest_path, + storage_options=self.storage_options, + transport=self.transport, + ) as fin: + for entry_index, line in enumerate(fin): + if not line.strip(): + continue + raw_entry = json.loads(line) + sample_key = build_audio_sample_key(raw_entry, dataset_name=task.dataset_name) + entry = dict(raw_entry) + if self.fields is not None: + entry = {field: entry[field] for field in self.fields if field in entry} + if self.audio_filepath_key in raw_entry and self.audio_filepath_key not in entry: + entry[self.audio_filepath_key] = 
raw_entry[self.audio_filepath_key] + entry[self.manifest_path_key] = manifest_path + entry[self.source_type_key] = self.source_type_value + results.append( + AudioTask( + task_id=f"{task.task_id}_{manifest_index}_{entry_index}", + dataset_name=task.dataset_name, + data=entry, + sample_key=sample_key, + _metadata=task._metadata, + _stage_perf=list(task._stage_perf), + ) + ) + return results + + +@dataclass +class AudioManifestReader(CompositeStage[_EmptyTask, AudioTask]): + manifest_paths: str | list[str] + files_per_partition: int | None = 1 + blocksize: int | str | None = None + file_extensions: list[str] = field(default_factory=lambda: [".jsonl", ".json"]) + fields: list[str] | None = None + storage_options: dict[str, Any] | None = None + transport: Literal["auto", "fsspec", "pipe"] = "auto" + limit: int | None = None + name: str = "audio_manifest_reader" + + def __post_init__(self) -> None: + super().__init__() + if not self.manifest_paths: + msg = "manifest_paths is required for AudioManifestReader" + raise ValueError(msg) + + def decompose(self) -> list[ProcessingStage]: + return [ + FilePartitioningStage( + file_paths=self.manifest_paths, + files_per_partition=self.files_per_partition, + blocksize=self.blocksize, + file_extensions=self.file_extensions, + storage_options=self.storage_options, + limit=self.limit, + ), + AudioManifestReaderStage( + fields=self.fields, + storage_options=self.storage_options, + transport=self.transport, + ), + ] + + def get_description(self) -> str: + parts = [f"Read audio manifests from {self.manifest_paths}"] + if self.files_per_partition: + parts.append(f"with {self.files_per_partition} files per partition") + elif self.blocksize: + parts.append(f"with target blocksize {self.blocksize}") + return ", ".join(parts) diff --git a/nemo_curator/stages/audio/io/materialize.py b/nemo_curator/stages/audio/io/materialize.py new file mode 100644 index 0000000000..4fff391a96 --- /dev/null +++ b/nemo_curator/stages/audio/io/materialize.py 
@@ -0,0 +1,235 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import hashlib +import io +import os +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import soundfile + +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.tasks import AudioTask +from nemo_curator.tasks.audio_task import build_audio_sample_key +from nemo_curator.utils.remote_io import basename_from_path + + +def _path_suffix(path: str, fallback: str = ".bin") -> str: + suffix = Path(basename_from_path(path)).suffix + return suffix if suffix else fallback + + +@dataclass +class BaseAudioMaterializeStage(ProcessingStage[AudioTask, AudioTask]): + audio_filepath_key: str = "audio_filepath" + manifest_audio_filepath_key: str = "_manifest_audio_filepath" + temporary_audio_key: str = "_temporary_audio_path" + materialization_mode_key: str = "_materialization_mode" + materialized_field_name_key: str = "_materialized_audio_field" + offset_key: str = "offset" + duration_key: str = "duration" + temp_dir: str | None = None + materialization_dir: str | None = None + segment_if_offset_present: bool = True + name: str = "base_materialize_audio" + + def process(self, task: AudioTask) -> AudioTask: + msg = f"{self.__class__.__name__} only supports process_batch" + raise NotImplementedError(msg) + + def _materialize_tasks_from_bytes( + self, 
+ tasks: list[AudioTask], + raw_audio: bytes, + source_name: str, + *, + reference_field: str, + output_field: str, + ) -> None: + field_paths = (reference_field, output_field) + segment_tasks: list[AudioTask] = [] + member_tasks_only: list[AudioTask] = [] + for task in tasks: + if self._should_segment(task, source_name, reference_field=field_paths[0]): + segment_tasks.append(task) + else: + member_tasks_only.append(task) + + decoded_audio: tuple[Any, int] | None = None + if segment_tasks: + decoded_audio = soundfile.read(io.BytesIO(raw_audio), dtype="float32") + + for task in member_tasks_only: + self._materialize_task( + task, + raw_audio, + source_name, + field_paths=field_paths, + ) + for task in segment_tasks: + self._materialize_task( + task, + raw_audio, + source_name, + field_paths=field_paths, + decoded_audio=decoded_audio, + ) + + def _materialize_task( + self, + task: AudioTask, + raw_audio: bytes, + source_name: str, + *, + field_paths: tuple[str, str], + decoded_audio: tuple[Any, int] | None = None, + ) -> None: + reference_field, output_field = field_paths + should_segment = self._should_segment(task, source_name, reference_field=reference_field) + output_path, is_temporary = self._create_output_path( + task, + output_field=output_field, + suffix=".wav" if should_segment else _path_suffix(source_name), + ) + if should_segment: + self._write_segmented_audio(task, raw_audio, output_path, decoded_audio=decoded_audio) + materialization_mode = "segment" + else: + output_path.write_bytes(raw_audio) + materialization_mode = "member" + + if output_field == self.audio_filepath_key: + task.data.setdefault(self.manifest_audio_filepath_key, task.data.get(self.audio_filepath_key)) + task.data[output_field] = output_path.as_posix() + task.data[self.materialized_field_name_key] = output_field + if is_temporary: + task.data[self.temporary_audio_key] = output_path.as_posix() + else: + task.data.pop(self.temporary_audio_key, None) + 
task.data[self.materialization_mode_key] = materialization_mode + + def _should_segment(self, task: AudioTask, source_name: str, *, reference_field: str) -> bool: + if not self.segment_if_offset_present: + return False + original_path = str(task.data.get(reference_field, "") or "") + offset = float(task.data.get(self.offset_key, 0.0) or 0.0) + return source_name != original_path or offset > 0.0 or task.data.get(self.duration_key) is not None + + def _create_temp_path(self, *, suffix: str) -> Path: + target_dir = Path(self.temp_dir) if self.temp_dir is not None else Path(tempfile.gettempdir()) + target_dir.mkdir(parents=True, exist_ok=True) + fd, path = tempfile.mkstemp(prefix="nemo_curator_materialized_audio_", suffix=suffix, dir=target_dir) + os.close(fd) + return Path(path) + + def _get_sample_key(self, task: AudioTask) -> str: + if task.sample_key: + return task.sample_key + task.sample_key = build_audio_sample_key(task.data, dataset_name=task.dataset_name) + return task.sample_key + + def _create_output_path(self, task: AudioTask, *, output_field: str, suffix: str) -> tuple[Path, bool]: + if self.materialization_dir is None: + return self._create_temp_path(suffix=suffix), True + + sample_basis = self._get_sample_key(task) + if output_field != self.audio_filepath_key: + sample_basis = f"{sample_basis}:{output_field}" + sample_hash = hashlib.sha256(sample_basis.encode("utf-8")).hexdigest() + target_dir = Path(self.materialization_dir) / sample_hash[:2] + target_dir.mkdir(parents=True, exist_ok=True) + return target_dir / f"{sample_hash}{suffix}", False + + def _write_segmented_audio( + self, + task: AudioTask, + raw_audio: bytes, + output_path: Path, + *, + decoded_audio: tuple[Any, int] | None = None, + ) -> None: + offset = float(task.data.get(self.offset_key, 0.0) or 0.0) + duration = task.data.get(self.duration_key) + if duration is not None and float(duration) <= 0.0: + msg = f"Duration must be greater than 0 for segmented audio, got {duration!r}" + raise 
RuntimeError(msg) + if decoded_audio is None: + waveform, sample_rate = soundfile.read(io.BytesIO(raw_audio), dtype="float32") + else: + waveform, sample_rate = decoded_audio + start = max(round(offset * sample_rate), 0) + member_name = str(task.data.get(self.audio_filepath_key, "unknown")) + if start >= waveform.shape[0]: + msg = f"Offset {offset}s exceeds audio length for source '{member_name}'" + raise RuntimeError(msg) + end = waveform.shape[0] + if duration is not None: + end = min(start + round(float(duration) * sample_rate), waveform.shape[0]) + soundfile.write(output_path.as_posix(), waveform[start:end], sample_rate) + + +@dataclass +class CleanupTemporaryAudioStage(ProcessingStage[AudioTask, AudioTask]): + temporary_audio_key: str = "_temporary_audio_path" + audio_filepath_key: str = "audio_filepath" + manifest_audio_filepath_key: str = "_manifest_audio_filepath" + materialization_mode_key: str = "_materialization_mode" + materialized_field_name_key: str = "_materialized_audio_field" + restore_manifest_audio_filepath: bool = True + drop_temporary_metadata: bool = True + drop_materialized_field: bool = True + ignore_missing: bool = True + name: str = "cleanup_temporary_audio" + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [self.audio_filepath_key] + + def process(self, task: AudioTask) -> AudioTask: + temp_path = task.data.get(self.temporary_audio_key) + if temp_path: + try: + os.unlink(temp_path) + except FileNotFoundError: + if not self.ignore_missing: + raise + + materialized_field_name = task.data.get(self.materialized_field_name_key) + if ( + self.restore_manifest_audio_filepath + and self.manifest_audio_filepath_key in task.data + and materialized_field_name in {None, self.audio_filepath_key} + ): + task.data[self.audio_filepath_key] = task.data[self.manifest_audio_filepath_key] + + if self.drop_temporary_metadata: + if ( + self.drop_materialized_field + and 
isinstance(materialized_field_name, str) + and materialized_field_name != self.audio_filepath_key + ): + task.data.pop(materialized_field_name, None) + task.data.pop(self.temporary_audio_key, None) + task.data.pop(self.materialization_mode_key, None) + task.data.pop(self.manifest_audio_filepath_key, None) + task.data.pop(self.materialized_field_name_key, None) + + return task diff --git a/nemo_curator/stages/audio/io/tarred.py b/nemo_curator/stages/audio/io/tarred.py new file mode 100644 index 0000000000..c45d32c473 --- /dev/null +++ b/nemo_curator/stages/audio/io/tarred.py @@ -0,0 +1,368 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import json +import re +import tarfile +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Literal + +from fsspec.core import url_to_fs +from loguru import logger + +from nemo_curator.backends.utils import RayStageSpecKeys +from nemo_curator.stages.audio.io.materialize import BaseAudioMaterializeStage +from nemo_curator.stages.base import CompositeStage, ProcessingStage +from nemo_curator.tasks import AudioTask, FileGroupTask, _EmptyTask +from nemo_curator.tasks.audio_task import build_audio_sample_key +from nemo_curator.utils.file_utils import infer_dataset_name_from_path +from nemo_curator.utils.remote_io import ( +    PipeStream, +    expand_sharded_paths, +    iter_tar_member_names, +    open_binary_stream, +    open_text_stream, +    resolve_transport, +) + +_MANIFEST_SHARD_PATTERN = re.compile(r"manifest[^/\s]*_(\d+)[^/\s]*\.(?:json|jsonl)(?:\.[^/\s]+)?") +_TAR_SHARD_PATTERN = re.compile(r"audio[^/\s]*_(\d+)[^/\s]*\.tar(?:\.[^/\s]+)?") +_OFFSET_MEMBER_PATTERN = re.compile(r"^(?P<stem>.+?)(?P<sub>-sub\d+)(?P<ext>\.[^.\\/]+)?$") +_SKIP_MISSING_ENTRIES_METADATA_KEY = "_tarred_skip_missing_entries" + +_PipeStream = PipeStream +_open_binary_stream = open_binary_stream +_open_text_stream = open_text_stream +_iter_tar_member_names = iter_tar_member_names + + +def _extract_shard_id(path: str, kind: Literal["manifest", "tar"]) -> int: + pattern = _MANIFEST_SHARD_PATTERN if kind == "manifest" else _TAR_SHARD_PATTERN + match = pattern.search(path) + if match is None: + msg = f"Cannot determine shard id from {kind} path/specifier: {path}" + raise ValueError(msg) + return int(match.group(1)) + + +def _normalize_tar_member(audio_filepath: str) -> str: + match = _OFFSET_MEMBER_PATTERN.match(audio_filepath) + if match is None: + return audio_filepath + stem = match.group("stem") + ext = match.group("ext") or "" + return f"{stem}{ext}" + + +def _partition_paths(paths: list[str], files_per_partition: int) -> list[list[str]]: 
+ return [paths[i : i + files_per_partition] for i in range(0, len(paths), files_per_partition)] + + +def _dataset_name_from_path(path: str) -> str: + if path.startswith("pipe:"): + return "dataset" + try: + return infer_dataset_name_from_path(path) + except Exception: # noqa: BLE001 + return "dataset" + + +def _should_validate_tar_members(path: str, transport: Literal["auto", "fsspec", "pipe"]) -> bool: + if resolve_transport(path, transport) == "pipe": + return False + fs, _ = url_to_fs(path) + protocol = fs.protocol[0] if isinstance(fs.protocol, (list, tuple)) else fs.protocol + return protocol in {None, "file"} + + +@dataclass +class TarredAudioManifestPartitionStage(ProcessingStage[_EmptyTask, FileGroupTask]): + manifest_paths: str | list[str] + files_per_partition: int = 1 + limit: int | None = None + name: str = "tarred_audio_manifest_partitioning" + + def __post_init__(self) -> None: + if self.files_per_partition <= 0: + msg = "files_per_partition must be positive" + raise ValueError(msg) + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def ray_stage_spec(self) -> dict[str, Any]: + return {RayStageSpecKeys.IS_FANOUT_STAGE: True} + + def process(self, _: _EmptyTask) -> list[FileGroupTask]: + manifest_files = expand_sharded_paths(self.manifest_paths) + if not manifest_files: + return [] + + partitions = _partition_paths(manifest_files, self.files_per_partition) + dataset_name = _dataset_name_from_path(manifest_files[0]) + tasks: list[FileGroupTask] = [] + for i, file_group in enumerate(partitions): + if self.limit is not None and len(tasks) >= self.limit: + break + tasks.append( + FileGroupTask( + task_id=f"manifest_group_{i}", + dataset_name=dataset_name, + data=file_group, + _metadata={ + "partition_index": i, + "total_partitions": len(partitions), + "source_files": file_group, + }, + ) + ) + return tasks + + +@dataclass +class 
TarredAudioManifestReaderStage(ProcessingStage[FileGroupTask, AudioTask]): + tar_paths: str | list[str] + storage_options: dict[str, Any] | None = None + transport: Literal["auto", "fsspec", "pipe"] = "auto" + audio_filepath_key: str = "audio_filepath" + tar_path_key: str = "_tar_path" + tar_member_key: str = "_tar_member" + shard_id_key: str = "_shard_id" + manifest_path_key: str = "_manifest_path" + source_type_key: str = "_audio_source_type" + skip_missing_entries_metadata_key: str = _SKIP_MISSING_ENTRIES_METADATA_KEY + skip_missing_entries: bool = False + name: str = "tarred_audio_manifest_reader" + + def __post_init__(self) -> None: + expanded_tar_paths = expand_sharded_paths(self.tar_paths) + self._shard_id_to_tar_path: dict[int, str] = {} + for tar_path in expanded_tar_paths: + shard_id = _extract_shard_id(tar_path, "tar") + if shard_id in self._shard_id_to_tar_path: + msg = f"Duplicate tar shard id {shard_id} for paths: {tar_path} and {self._shard_id_to_tar_path[shard_id]}" + raise ValueError(msg) + self._shard_id_to_tar_path[shard_id] = tar_path + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return ["sample_key"], [ + self.audio_filepath_key, + self.tar_path_key, + self.tar_member_key, + self.shard_id_key, + self.manifest_path_key, + self.source_type_key, + ] + + def ray_stage_spec(self) -> dict[str, Any]: + return {RayStageSpecKeys.IS_FANOUT_STAGE: True} + + def process(self, task: FileGroupTask) -> list[AudioTask]: + results: list[AudioTask] = [] + for manifest_index, manifest_path in enumerate(task.data): + shard_id = _extract_shard_id(manifest_path, "manifest") + if shard_id not in self._shard_id_to_tar_path: + msg = f"No tar shard found for manifest shard {shard_id}: {manifest_path}" + raise RuntimeError(msg) + tar_path = self._shard_id_to_tar_path[shard_id] + tar_members = None + if _should_validate_tar_members(tar_path, self.transport): + tar_members = set( + 
_iter_tar_member_names( + tar_path, + storage_options=self.storage_options, + transport=self.transport, + ) + ) + with _open_text_stream( + manifest_path, + storage_options=self.storage_options, + transport=self.transport, + ) as fin: + for entry_index, line in enumerate(fin): + if not line.strip(): + continue + entry = json.loads(line) + manifest_audio_path = entry[self.audio_filepath_key] + tar_member = _normalize_tar_member(manifest_audio_path) + if tar_members is not None and tar_member not in tar_members: + msg = ( + f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). " + f"Cannot locate tar member '{tar_member}' referenced by '{manifest_audio_path}'" + ) + if self.skip_missing_entries: + logger.warning(msg) + continue + raise RuntimeError(msg) + + item = dict(entry) + item[self.tar_path_key] = tar_path + item[self.tar_member_key] = tar_member + item[self.shard_id_key] = shard_id + item[self.manifest_path_key] = manifest_path + item[self.source_type_key] = "tarred" + results.append( + AudioTask( + task_id=f"{task.task_id}_{manifest_index}_{entry_index}", + dataset_name=task.dataset_name, + data=item, + sample_key=build_audio_sample_key(item, dataset_name=task.dataset_name), + _metadata={ + **task._metadata, + self.skip_missing_entries_metadata_key: self.skip_missing_entries, + }, + _stage_perf=list(task._stage_perf), + ) + ) + return results + + +@dataclass +class TarredAudioManifestReader(CompositeStage[_EmptyTask, AudioTask]): + manifest_paths: str | list[str] + tar_paths: str | list[str] + files_per_partition: int = 1 + limit: int | None = None + storage_options: dict[str, Any] | None = None + transport: Literal["auto", "fsspec", "pipe"] = "auto" + skip_missing_entries: bool = False + name: str = "tarred_audio_manifest_reader" + + def __post_init__(self) -> None: + super().__init__() + if not self.manifest_paths: + msg = "manifest_paths is required for TarredAudioManifestReader" + raise ValueError(msg) + if not 
self.tar_paths: + msg = "tar_paths is required for TarredAudioManifestReader" + raise ValueError(msg) + + def decompose(self) -> list[ProcessingStage]: + return [ + TarredAudioManifestPartitionStage( + manifest_paths=self.manifest_paths, + files_per_partition=self.files_per_partition, + limit=self.limit, + ), + TarredAudioManifestReaderStage( + tar_paths=self.tar_paths, + storage_options=self.storage_options, + transport=self.transport, + skip_missing_entries=self.skip_missing_entries, + ), + ] + + def get_description(self) -> str: + return ( + f"Read tarred audio manifests from {self.manifest_paths} and match shards against {self.tar_paths}" + ) + + +@dataclass +class MaterializeTarredAudioStage(BaseAudioMaterializeStage): + tar_path_key: str = "_tar_path" + tar_member_key: str = "_tar_member" + transport: Literal["auto", "fsspec", "pipe"] = "auto" + storage_options: dict[str, Any] | None = None + skip_missing_entries_metadata_key: str = _SKIP_MISSING_ENTRIES_METADATA_KEY + name: str = "materialize_tarred_audio" + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [self.tar_path_key, self.tar_member_key] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [ + self.audio_filepath_key, + self.manifest_audio_filepath_key, + self.temporary_audio_key, + self.materialization_mode_key, + self.materialized_field_name_key, + ] + + def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: + if len(tasks) == 0: + return [] + + for task in tasks: + if not self.validate_input(task): + msg = f"Task {task!s} failed validation for stage {self}" + raise ValueError(msg) + + grouped_tasks: dict[str, dict[str, list[AudioTask]]] = defaultdict(lambda: defaultdict(list)) + for task in tasks: + grouped_tasks[task.data[self.tar_path_key]][task.data[self.tar_member_key]].append(task) + + skipped_task_ids: set[str] = set() + for tar_path, member_tasks in grouped_tasks.items(): + skipped_task_ids.update(self._materialize_from_tar(tar_path, member_tasks)) 
+ + return [task for task in tasks if task._uuid not in skipped_task_ids] + + def _materialize_from_tar(self, tar_path: str, member_tasks: dict[str, list[AudioTask]]) -> set[str]: + remaining = set(member_tasks) + with ( + _open_binary_stream( + tar_path, + storage_options=self.storage_options, + transport=self.transport, + allow_sigpipe=True, + ) as stream, + tarfile.open(fileobj=stream, mode="r|*") as tar, + ): + for tar_info in tar: + if not tar_info.isfile() or tar_info.name not in member_tasks: + continue + extracted = tar.extractfile(tar_info) + if extracted is None: + continue + raw_audio = extracted.read() + self._materialize_tasks_from_bytes( + member_tasks[tar_info.name], + raw_audio, + tar_info.name, + reference_field=self.audio_filepath_key, + output_field=self.audio_filepath_key, + ) + remaining.discard(tar_info.name) + if not remaining: + break + + skipped_task_ids: set[str] = set() + if remaining: + missing_non_skippable: list[str] = [] + for member_name in sorted(remaining): + tasks_for_member = member_tasks[member_name] + if all(bool(task._metadata.get(self.skip_missing_entries_metadata_key, False)) for task in tasks_for_member): + logger.warning( + "Skipping missing tar member '{}' from tar shard '{}' for {} task(s)", + member_name, + tar_path, + len(tasks_for_member), + ) + skipped_task_ids.update(task._uuid for task in tasks_for_member) + else: + missing_non_skippable.append(member_name) + if missing_non_skippable: + msg = f"Failed to materialize tar members {missing_non_skippable} from tar shard '{tar_path}'" + raise RuntimeError(msg) + return skipped_task_ids diff --git a/nemo_curator/stages/file_io/__init__.py b/nemo_curator/stages/file_io/__init__.py new file mode 100644 index 0000000000..adf2ec5183 --- /dev/null +++ b/nemo_curator/stages/file_io/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_curator.stages.file_io.files import ( + DeleteFilesStage, + MaterializeFilesStage, + UploadFilesStage, + UploadManifestStage, +) + +__all__ = [ + "DeleteFilesStage", + "MaterializeFilesStage", + "UploadFilesStage", + "UploadManifestStage", +] diff --git a/nemo_curator/stages/file_io/files.py b/nemo_curator/stages/file_io/files.py new file mode 100644 index 0000000000..a2be112d8c --- /dev/null +++ b/nemo_curator/stages/file_io/files.py @@ -0,0 +1,358 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import annotations

import hashlib
import json
import os
import posixpath
import tempfile
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal

from nemo_curator.stages.base import ProcessingStage
from nemo_curator.tasks import FileGroupTask, Task
from nemo_curator.utils.remote_io import (
    basename_from_path,
    build_remote_uri,
    copy_path,
    read_binary,
    remove_path,
)

# Convenience alias: a task whose payload is a single dict (e.g. one manifest entry).
DictTask = Task[dict[str, Any]]


def _resolve_task_storage_options(
    task: Task[Any],
    *,
    metadata_key: str,
    stage_options: dict[str, Any] | None,
) -> dict[str, Any]:
    """Return per-task storage options when present, else the stage-level default.

    Task metadata under ``metadata_key`` wins only when it is a non-empty dict;
    otherwise ``stage_options`` (or ``{}``) is used.
    """
    task_options = task._metadata.get(metadata_key)
    if isinstance(task_options, dict) and task_options:
        return task_options
    return stage_options or {}


def _split_field_path(field_path: str) -> list[str]:
    """Split a dotted field path (e.g. ``"a.b.c"``) into non-empty parts.

    Raises:
        ValueError: if the path contains no non-empty segment.
    """
    parts = [part for part in field_path.split(".") if part]
    if not parts:
        msg = "field path must not be empty"
        raise ValueError(msg)
    return parts


def _task_data_as_dict(task: DictTask) -> dict[str, Any]:
    """Return ``task.data`` after asserting it is dict-backed."""
    if not isinstance(task.data, dict):
        msg = f"{type(task).__name__} must have dict-backed data for file field stages"
        raise TypeError(msg)
    return task.data


def _resolve_field_path(data: dict[str, Any], field_path: str) -> Any:  # noqa: ANN401
    """Walk ``data`` along a dotted path and return the value at the end.

    Raises:
        TypeError: if an intermediate value along the path is not a dict.
        KeyError: if a key along the path is missing.
    """
    current: Any = data
    traversed: list[str] = []
    for part in _split_field_path(field_path):
        traversed.append(part)
        if not isinstance(current, dict):
            msg = f"Field path '{field_path}' is not addressable past '{'.'.join(traversed[:-1])}'"
            raise TypeError(msg)
        if part not in current:
            msg = f"Field path '{field_path}' is missing key '{part}'"
            raise KeyError(msg)
        current = current[part]
    return current


def _set_field_path(data: dict[str, Any], field_path: str, value: object) -> None:
    """Set ``value`` at a dotted path, creating intermediate dicts as needed.

    Raises:
        TypeError: if an existing intermediate value is not a dict.
    """
    parts = _split_field_path(field_path)
    current: dict[str, Any] = data
    traversed: list[str] = []
    for part in parts[:-1]:
        traversed.append(part)
        next_value = current.get(part)
        if next_value is None:
            # Missing intermediate levels are created on the fly.
            next_value = {}
            current[part] = next_value
        elif not isinstance(next_value, dict):
            msg = f"Field path '{field_path}' cannot be created past '{'.'.join(traversed)}'"
            raise TypeError(msg)
        current = next_value
    current[parts[-1]] = value


def _delete_field_path(data: dict[str, Any], field_path: str) -> None:
    """Delete the value at a dotted path; raise if any segment is missing."""
    parts = _split_field_path(field_path)
    current: dict[str, Any] = data
    for part in parts[:-1]:
        next_value = current.get(part)
        if not isinstance(next_value, dict):
            msg = f"Field path '{field_path}' is not addressable past '{part}'"
            raise TypeError(msg)
        current = next_value
    if parts[-1] not in current:
        msg = f"Field path '{field_path}' is missing key '{parts[-1]}'"
        raise KeyError(msg)
    del current[parts[-1]]


def _resolved_path_value(data: dict[str, Any], field_path: str) -> str:
    """Resolve a dotted path to a non-empty, whitespace-stripped path/URI string."""
    value = _resolve_field_path(data, field_path)
    if not isinstance(value, str):
        msg = f"Field path '{field_path}' must resolve to a string path or URI"
        raise TypeError(msg)
    normalized = value.strip()
    if not normalized:
        msg = f"Field path '{field_path}' resolved to an empty string"
        raise ValueError(msg)
    return normalized


def _path_suffix(path: str, fallback: str = ".bin") -> str:
    """Return the file extension of ``path`` (with leading dot), or ``fallback``."""
    suffix = posixpath.splitext(basename_from_path(path))[1]
    return suffix if suffix else fallback


@dataclass
class MaterializeFilesStage(ProcessingStage[DictTask, DictTask]):
    """Download files referenced by a task field to local disk.

    Reads the remote path/URI at ``source_field_path`` and writes the local
    copy's path into ``output_field_path``. Operates on batches so tasks that
    share the same (source path, storage options) pair download once per batch.
    """

    source_field_path: str  # dotted path to the remote path/URI inside task.data
    output_field_path: str  # dotted path that receives the local file path
    temp_dir: str | None = None  # directory for one-off temp copies (default: system temp)
    materialization_dir: str | None = None  # content-addressed cache dir; enables shared copies
    storage_options: dict[str, Any] | None = None  # stage-level fsspec options for the source
    transport: Literal["auto", "fsspec", "pipe"] = "auto"
    source_storage_metadata_key: str = "source_storage_options"  # per-task override key in metadata
    name: str = "materialize_files"

    def __post_init__(self) -> None:
        # Validate both dotted paths eagerly so misconfiguration fails at build time.
        _split_field_path(self.source_field_path)
        _split_field_path(self.output_field_path)

    def inputs(self) -> tuple[list[str], list[str]]:
        return [], []

    def outputs(self) -> tuple[list[str], list[str]]:
        return [], []

    def process(self, task: DictTask) -> DictTask:
        # Single-task processing is intentionally unsupported; batching is what
        # enables download deduplication across tasks sharing a source.
        msg = "MaterializeFilesStage only supports process_batch"
        raise NotImplementedError(msg)

    def process_batch(self, tasks: list[DictTask]) -> list[DictTask]:
        """Materialize every task's source file locally.

        Tasks are grouped by (source path, serialized storage options) so each
        distinct source is read only once per batch.
        """
        if len(tasks) == 0:
            return []

        grouped_tasks: dict[tuple[str, str], list[DictTask]] = defaultdict(list)
        source_options_by_group: dict[tuple[str, str], dict[str, Any]] = {}
        for task in tasks:
            data = _task_data_as_dict(task)
            source_path = _resolved_path_value(data, self.source_field_path)
            source_storage_options = _resolve_task_storage_options(
                task,
                metadata_key=self.source_storage_metadata_key,
                stage_options=self.storage_options,
            )
            # JSON-encode options so the grouping key stays hashable and stable.
            group_key = (source_path, json.dumps(source_storage_options, sort_keys=True, default=str))
            grouped_tasks[group_key].append(task)
            source_options_by_group[group_key] = source_storage_options

        for group_key, tasks_for_source in grouped_tasks.items():
            source_path, _storage_key = group_key
            raw_bytes = read_binary(
                source_path,
                storage_options=source_options_by_group[group_key],
                transport=self.transport,
                allow_sigpipe=True,
            )
            shared_output_path: Path | None = None
            if self.materialization_dir is not None:
                # Content-addressed mode: every task in the group points at one shared copy.
                shared_output_path = self._create_output_path(source_path)
                shared_output_path.write_bytes(raw_bytes)
            for task in tasks_for_source:
                output_path = shared_output_path if shared_output_path is not None else self._create_output_path(source_path)
                if shared_output_path is None:
                    # Temp-file mode: each task gets its own private copy.
                    output_path.write_bytes(raw_bytes)
                _set_field_path(_task_data_as_dict(task), self.output_field_path, output_path.as_posix())

        return tasks

    def _create_output_path(self, source_path: str) -> Path:
        """Choose the local destination: content-addressed cache path or temp file."""
        suffix = _path_suffix(source_path)
        if self.materialization_dir is None:
            return self._create_temp_path(suffix=suffix)

        # Hash the identity (source + destination field) so the same source
        # materialized for the same field maps to a deterministic cache path.
        materialization_identity = {
            "source_path": source_path,
            "output_field_path": self.output_field_path,
        }
        identity_json = json.dumps(materialization_identity, sort_keys=True, separators=(",", ":"))
        source_hash = hashlib.sha256(identity_json.encode("utf-8")).hexdigest()
        # Two-level fan-out (first two hex chars) keeps directory entry counts bounded.
        target_dir = Path(self.materialization_dir) / source_hash[:2]
        target_dir.mkdir(parents=True, exist_ok=True)
        return target_dir / f"{source_hash}{suffix}"

    def _create_temp_path(self, *, suffix: str) -> Path:
        """Reserve a uniquely named empty temp file and return its path."""
        target_dir = Path(self.temp_dir) if self.temp_dir is not None else Path(tempfile.gettempdir())
        target_dir.mkdir(parents=True, exist_ok=True)
        fd, path = tempfile.mkstemp(prefix="nemo_curator_materialized_file_", suffix=suffix, dir=target_dir)
        os.close(fd)  # mkstemp opens the file; only the reserved name is needed here
        return Path(path)


@dataclass
class UploadFilesStage(ProcessingStage[DictTask, DictTask]):
    """Upload the file referenced by a task field to remote object storage.

    The resulting remote URI is written back into ``output_field_path``.
    """

    source_field_path: str  # dotted path to the local/remote source inside task.data
    output_field_path: str  # dotted path that receives the destination URI
    bucket: str  # destination bucket (required, non-empty)
    protocol: str = "s3"  # URI scheme of the destination store
    key_prefix: str = ""  # optional prefix prepended to every object key
    key_field_path: str | None = None  # optional dotted path supplying the object key
    storage_options: dict[str, Any] | None = None  # destination fsspec options
    source_storage_options: dict[str, Any] | None = None  # stage-level source options
    source_transport: Literal["auto", "fsspec", "pipe"] = "auto"
    source_storage_metadata_key: str = "source_storage_options"
    name: str = "upload_files"

    def __post_init__(self) -> None:
        # Fail fast on bad field paths or a missing bucket.
        _split_field_path(self.source_field_path)
        _split_field_path(self.output_field_path)
        if not self.bucket.strip("/"):
            msg = "bucket is required for UploadFilesStage"
            raise ValueError(msg)
        if self.key_field_path is not None:
            _split_field_path(self.key_field_path)

    def inputs(self) -> tuple[list[str], list[str]]:
        return [], []

    def outputs(self) -> tuple[list[str], list[str]]:
        return [], []

    def process(self, task: DictTask) -> DictTask:
        """Copy the task's source file to the bucket and record the remote URI."""
        data = _task_data_as_dict(task)
        source_path = _resolved_path_value(data, self.source_field_path)
        destination_uri = build_remote_uri(
            protocol=self.protocol,
            bucket=self.bucket,
            key=self._build_object_key(data, source_path),
        )
        copy_path(
            source_path,
            destination_uri,
            source_storage_options=_resolve_task_storage_options(
                task,
                metadata_key=self.source_storage_metadata_key,
                stage_options=self.source_storage_options,
            ),
            destination_storage_options=self.storage_options,
            source_transport=self.source_transport,
        )
        _set_field_path(data, self.output_field_path, destination_uri)
        return task

    def _build_object_key(self, data: dict[str, Any], source_path: str) -> str:
        """Derive the object key from ``key_field_path`` or the source basename."""
        if self.key_field_path is not None:
            key_name = _resolved_path_value(data, self.key_field_path).strip("/")
        else:
            key_name = basename_from_path(source_path)
        if self.key_prefix:
            return posixpath.join(self.key_prefix.strip("/"), key_name)
        return key_name


@dataclass
class DeleteFilesStage(ProcessingStage[DictTask, DictTask]):
    """Delete the file referenced by a task field, then drop the field itself."""

    source_field_path: str  # dotted path to the path/URI to remove
    storage_options: dict[str, Any] | None = None  # stage-level fsspec options
    ignore_missing: bool = True  # tolerate already-deleted files
    source_storage_metadata_key: str = "source_storage_options"
    name: str = "delete_files"

    def __post_init__(self) -> None:
        _split_field_path(self.source_field_path)

    def inputs(self) -> tuple[list[str], list[str]]:
        return [], []

    def outputs(self) -> tuple[list[str], list[str]]:
        return [], []

    def process(self, task: DictTask) -> DictTask:
        """Remove the referenced file and delete its field from the task data."""
        data = _task_data_as_dict(task)
        remove_path(
            _resolved_path_value(data, self.source_field_path),
            storage_options=_resolve_task_storage_options(
                task,
                metadata_key=self.source_storage_metadata_key,
                stage_options=self.storage_options,
            ),
            ignore_missing=self.ignore_missing,
        )
        # The field no longer points at anything valid, so drop it.
        _delete_field_path(data, self.source_field_path)
        return task


@dataclass
class UploadManifestStage(ProcessingStage[FileGroupTask, FileGroupTask]):
    """Upload every file in a ``FileGroupTask`` to remote object storage.

    Returns a new task whose data lists the destination URIs; the original
    local paths are preserved in task metadata under ``local_source_files``.
    """

    bucket: str  # destination bucket (required, non-empty)
    protocol: str = "s3"
    key_prefix: str = ""
    storage_options: dict[str, Any] | None = None  # destination fsspec options
    source_storage_options: dict[str, Any] | None = None
    source_transport: Literal["auto", "fsspec", "pipe"] = "auto"
    name: str = "upload_manifest"

    def __post_init__(self) -> None:
        if not self.bucket.strip("/"):
            msg = "bucket is required for UploadManifestStage"
            raise ValueError(msg)

    def inputs(self) -> tuple[list[str], list[str]]:
        return [], []

    def outputs(self) -> tuple[list[str], list[str]]:
        return [], []

    def process(self, task: FileGroupTask) -> FileGroupTask:
        """Copy each listed file to the bucket (key = basename under the prefix)."""
        uploaded_paths: list[str] = []
        for source_path in task.data:
            key_name = basename_from_path(source_path)
            if self.key_prefix:
                key_name = posixpath.join(self.key_prefix.strip("/"), key_name)
            destination_uri = build_remote_uri(protocol=self.protocol, bucket=self.bucket, key=key_name)
            copy_path(
                source_path,
                destination_uri,
                source_storage_options=self.source_storage_options,
                destination_storage_options=self.storage_options,
                source_transport=self.source_transport,
            )
            uploaded_paths.append(destination_uri)

        # Emit a fresh task pointing at the remote copies; keep provenance in metadata.
        return FileGroupTask(
            task_id=task.task_id,
            dataset_name=task.dataset_name,
            data=uploaded_paths,
            _metadata={
                **task._metadata,
                "uploaded_files": uploaded_paths,
                "local_source_files": list(task.data),
            },
            _stage_perf=task._stage_perf,
            reader_config=task.reader_config,
        )
import json
from dataclasses import dataclass, field
from typing import Any, Literal

import pandas as pd
import ray
from fsspec.core import url_to_fs
from loguru import logger

from nemo_curator.backends.utils import RayStageSpecKeys
from nemo_curator.stages.base import CompositeStage, ProcessingStage
from nemo_curator.stages.file_partitioning import FilePartitioningStage
from nemo_curator.tasks import AudioTask, DocumentBatch, FileGroupTask, _EmptyTask
from nemo_curator.tasks.audio_task import build_audio_sample_key
from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS, pandas_select_columns

from .base import BaseReader


@dataclass
class JsonlAudioReaderStage(ProcessingStage[FileGroupTask, AudioTask]):
    """Stage that streams JSONL manifests and emits one ``AudioTask`` per line.

    Unlike ``JsonlReaderStage``, this stage avoids Pandas and fans out each JSONL
    entry into an ``AudioTask``. This keeps audio manifests compatible with
    downstream audio stages and avoids materializing nested metadata as a
    ``DocumentBatch``.
    """

    fields: list[str] | None = None  # optional projection: keep only these manifest keys
    read_kwargs: dict[str, Any] = field(default_factory=dict)  # storage/open options
    _generate_ids: bool = False  # mint fresh dedup IDs via the ID generator actor
    _assign_ids: bool = False  # reuse a pre-registered dedup ID range
    name: str = "jsonl_audio_reader"

    def __post_init__(self) -> None:
        # The two ID modes are mutually exclusive.
        if self._generate_ids and self._assign_ids:
            msg = "Cannot generate and assign IDs at the same time"
            raise ValueError(msg)

    def inputs(self) -> tuple[list[str], list[str]]:
        return [], []

    def outputs(self) -> tuple[list[str], list[str]]:
        output_fields = list(self.fields or [])
        if self._generate_ids or self._assign_ids:
            from nemo_curator.stages.deduplication.id_generator import CURATOR_DEDUP_ID_STR

            output_fields.append(CURATOR_DEDUP_ID_STR)
        return ["sample_key"], output_fields

    def setup(self, _: Any = None) -> None:  # noqa: ANN401
        """Bind the ID generator actor when either ID mode is enabled.

        Raises:
            RuntimeError: if the 'id_generator' actor has not been started.
        """
        if self._generate_ids or self._assign_ids:
            from nemo_curator.stages.deduplication.id_generator import get_id_generator_actor

            try:
                self.id_generator = get_id_generator_actor()
            except ValueError:
                msg = (
                    "ID generator is required when self._generate_ids or self._assign_ids is True, "
                    "and the actor 'id_generator' does not exist. Please start the id_generator actor."
                )
                raise RuntimeError(msg) from None

    def _apply_generated_ids(self, file_paths: list[str], tasks: list[AudioTask]) -> None:
        """Register this batch with the ID generator and stamp sequential IDs."""
        from nemo_curator.stages.deduplication.id_generator import CURATOR_DEDUP_ID_STR

        # Never overwrite IDs that already exist in the manifest entries.
        if any(CURATOR_DEDUP_ID_STR in task.data for task in tasks):
            logger.warning(f"Column {CURATOR_DEDUP_ID_STR} already exists in {file_paths}, not generating new IDs")
            return

        min_id = ray.get(self.id_generator.register_batch.remote(file_paths, len(tasks)))
        for offset, task in enumerate(tasks):
            task.data[CURATOR_DEDUP_ID_STR] = min_id + offset

    def _apply_assigned_ids(self, file_paths: list[str], tasks: list[AudioTask]) -> None:
        """Stamp IDs from a range previously registered for these files.

        Raises:
            RuntimeError: if the registered range has fewer IDs than tasks.
        """
        from nemo_curator.stages.deduplication.id_generator import CURATOR_DEDUP_ID_STR

        if any(CURATOR_DEDUP_ID_STR in task.data for task in tasks):
            logger.warning(f"Column {CURATOR_DEDUP_ID_STR} already exists in {file_paths}, not re-assigning IDs")
            return

        min_id, max_id = ray.get(self.id_generator.get_batch_range.remote(file_paths, None))
        assigned_count = max_id - min_id + 1
        task_count = len(tasks)
        if assigned_count < task_count:
            msg = (
                f"Assigned ID range for {file_paths} contains {assigned_count} IDs, but the audio JSONL reader "
                f"produced {task_count} tasks. Ensure the batch was pre-registered with the number of non-blank "
                "JSONL entries."
            )
            raise RuntimeError(msg)
        if assigned_count > task_count:
            # A surplus is tolerated (blank lines were skipped); the extra IDs go unused.
            logger.warning(
                "Assigned ID range for {} contains {} IDs, but the audio JSONL reader produced {} tasks "
                "after skipping blank lines. Assigning the first {} IDs from the registered range.",
                file_paths,
                assigned_count,
                task_count,
                task_count,
            )

        for next_id, task in zip(range(min_id, max_id + 1), tasks, strict=False):
            task.data[CURATOR_DEDUP_ID_STR] = next_id

    def process(self, task: FileGroupTask) -> list[AudioTask]:
        """Read JSONL files line-by-line and return one ``AudioTask`` per entry."""
        read_kwargs = {} if self.read_kwargs is None else dict(self.read_kwargs)
        # Only line-delimited JSON makes sense for this reader.
        if "lines" in read_kwargs and read_kwargs["lines"] is False:
            msg = "lines=False is not supported for JSONL reader"
            raise ValueError(msg)
        read_kwargs.pop("lines", None)

        # Split the supported options into fs options and text-open options.
        storage_options = read_kwargs.pop("storage_options", None) or {}
        open_kwargs = {
            key: read_kwargs.pop(key)
            for key in ("compression", "encoding", "errors", "newline")
            if key in read_kwargs
        }
        open_kwargs.setdefault("encoding", "utf-8")

        if read_kwargs:
            logger.warning(f"Ignoring unsupported read_kwargs for audio JSONL reader: {sorted(read_kwargs.keys())}")

        results: list[AudioTask] = []
        for file_path in task.data:
            fs, resolved = url_to_fs(file_path, **storage_options)
            with fs.open(resolved, "r", **open_kwargs) as f:
                for line in f:
                    # Blank lines are skipped, not errors.
                    if not line.strip():
                        continue
                    raw_entry = json.loads(line)
                    # Derive the sample key from the full entry, before any projection.
                    sample_key = build_audio_sample_key(raw_entry, dataset_name=task.dataset_name)
                    entry = raw_entry
                    if self.fields is not None:
                        entry = {field: entry[field] for field in self.fields if field in entry}
                    results.append(
                        AudioTask(
                            task_id=f"{task.task_id}_{len(results)}",
                            dataset_name=task.dataset_name,
                            data=entry,
                            sample_key=sample_key,
                            _metadata=task._metadata,
                            _stage_perf=list(task._stage_perf),
                        )
                    )

        if results:
            if self._generate_ids:
                self._apply_generated_ids(task.data, results)
            elif self._assign_ids:
                self._apply_assigned_ids(task.data, results)

        return results

    def ray_stage_spec(self) -> dict[str, Any]:
        # One input FileGroupTask fans out into many AudioTasks.
        return {RayStageSpecKeys.IS_FANOUT_STAGE: True}
JsonlReader(CompositeStage[_EmptyTask, DocumentBatch | AudioTask]): """Composite stage for reading JSONL files. - This high-level stage decomposes into: - 1. FilePartitioningStage - partitions files into groups - 2. JsonlReaderStage - reads file groups into DocumentBatches + The output type depends on ``task_type``: + 1. ``document`` -> ``FilePartitioningStage`` + ``JsonlReaderStage`` -> ``DocumentBatch`` + 2. ``audio`` -> ``FilePartitioningStage`` + ``JsonlAudioReaderStage`` -> ``AudioTask`` """ file_paths: str | list[str] @@ -106,13 +251,13 @@ def __post_init__(self): if self.read_kwargs is not None: self.storage_options = self.read_kwargs.get("storage_options", {}) - def decompose(self) -> list[JsonlReaderStage]: + def decompose(self) -> list[ProcessingStage]: """Decompose into file partitioning and processing stages.""" - if self.task_type != "document": + if self.task_type not in {"document", "audio"}: msg = f"Converting DocumentBatch to {self.task_type} is not supported yet." raise NotImplementedError(msg) - return [ + stages: list[ProcessingStage] = [ FilePartitioningStage( file_paths=self.file_paths, files_per_partition=self.files_per_partition, @@ -121,15 +266,30 @@ def decompose(self) -> list[JsonlReaderStage]: storage_options=self.read_kwargs.get("storage_options", None) if self.read_kwargs is not None else None, - ), - JsonlReaderStage( - fields=self.fields, - read_kwargs=(self.read_kwargs or {}), - _generate_ids=self._generate_ids, - _assign_ids=self._assign_ids, - ), + ) ] + if self.task_type == "audio": + stages.append( + JsonlAudioReaderStage( + fields=self.fields, + read_kwargs=(self.read_kwargs or {}), + _generate_ids=self._generate_ids, + _assign_ids=self._assign_ids, + ) + ) + else: + stages.append( + JsonlReaderStage( + fields=self.fields, + read_kwargs=(self.read_kwargs or {}), + _generate_ids=self._generate_ids, + _assign_ids=self._assign_ids, + ) + ) + + return stages + def get_description(self) -> str: """Get a description of this composite 
stage.""" diff --git a/nemo_curator/tasks/audio_task.py b/nemo_curator/tasks/audio_task.py index 94e4e5f5a3..b4ffa6d95b 100644 --- a/nemo_curator/tasks/audio_task.py +++ b/nemo_curator/tasks/audio_task.py @@ -12,33 +12,84 @@ # See the License for the specific language governing permissions and # limitations under the License. +import hashlib +import json import os +from collections.abc import Mapping from dataclasses import dataclass, field +from typing import Any from loguru import logger from .tasks import Task +AUDIO_SAMPLE_KEY_FIELD = "sample_key" + class _AttrDict(dict): """Dict subclass exposing keys as attributes so ``hasattr`` works.""" - def __getattr__(self, key: str): + def __getattr__(self: "_AttrDict", key: str) -> object: try: return self[key] except KeyError: raise AttributeError(key) from None - def __setattr__(self, key: str, value: object) -> None: + def __setattr__(self: "_AttrDict", key: str, value: object) -> None: self[key] = value - def __delattr__(self, key: str): + def __delattr__(self: "_AttrDict", key: str) -> None: try: del self[key] except KeyError: raise AttributeError(key) from None +def _normalize_sample_key_value(value: Any) -> Any: # noqa: ANN401 + if value is None: + return None + if isinstance(value, bool): + return value + if isinstance(value, int): + return value + if isinstance(value, float): + return float(value) + if isinstance(value, str): + normalized = value.strip() + return normalized if normalized else None + return str(value) + + +def build_audio_sample_key( + data: Mapping[str, Any], + *, + dataset_name: str = "", + sample_key_field: str = AUDIO_SAMPLE_KEY_FIELD, +) -> str: + """Build a stable sample key for an audio entry. + + If the input already contains an explicit ``sample_key`` value, preserve it. + Otherwise derive a deterministic hash from the sample identity fields. 
+ """ + + existing = _normalize_sample_key_value(data.get(sample_key_field)) + if existing is not None: + return str(existing) + + identity = { + "dataset_name": _normalize_sample_key_value(dataset_name), + "source_type": _normalize_sample_key_value(data.get("_audio_source_type")), + "audio_filepath": _normalize_sample_key_value(data.get("audio_filepath")), + "tar_path": _normalize_sample_key_value(data.get("_tar_path")), + "tar_member": _normalize_sample_key_value(data.get("_tar_member")), + "shard_id": _normalize_sample_key_value(data.get("_shard_id")), + "offset": _normalize_sample_key_value(data.get("offset")), + "duration": _normalize_sample_key_value(data.get("duration")), + } + identity_json = json.dumps(identity, sort_keys=True, separators=(",", ":"), ensure_ascii=True) + return hashlib.sha256(identity_json.encode("utf-8")).hexdigest() + + @dataclass class AudioTask(Task[dict]): """A single audio manifest entry. @@ -56,11 +107,18 @@ class AudioTask(Task[dict]): task_id: str = "" dataset_name: str = "" data: dict = field(default_factory=_AttrDict) + sample_key: str = "" filepath_key: str | None = None - def __post_init__(self): + def __post_init__(self) -> None: if not isinstance(self.data, _AttrDict): self.data = _AttrDict(self.data) + if not self.sample_key: + existing = self.data.get(AUDIO_SAMPLE_KEY_FIELD) + if isinstance(existing, str) and existing.strip(): + self.sample_key = existing.strip() + if self.sample_key: + self.data.setdefault(AUDIO_SAMPLE_KEY_FIELD, self.sample_key) @property def num_items(self) -> int: diff --git a/nemo_curator/utils/remote_io.py b/nemo_curator/utils/remote_io.py new file mode 100644 index 0000000000..891a8baad6 --- /dev/null +++ b/nemo_curator/utils/remote_io.py @@ -0,0 +1,252 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import io
import posixpath
import re
import shutil
import signal
import subprocess
import tarfile
from contextlib import contextmanager
from typing import IO, TYPE_CHECKING, Any, Literal

import fsspec
from fsspec.core import url_to_fs

if TYPE_CHECKING:
    from collections.abc import Iterator

# Shard-range spellings: WebDataset-style "_OP_0..9_CL_" and brace-range "{0..9}".
_OP_CL_PATTERN = re.compile(r"_OP_(\d+)\.\.(\d+)_CL_")
_BRACE_RANGE_PATTERN = re.compile(r"\{(\d+)\.\.(\d+)\}")


class PipeStream:
    """Context manager exposing a shell command's stdout as a binary stream.

    Used for ``pipe:`` style sources; the command runs under ``/bin/bash`` and
    its stdout is yielded to the caller.
    """

    def __init__(self, command: str, *, allow_sigpipe: bool = False):
        self.command = command
        # When True, a SIGPIPE exit (reader closed early) is not an error.
        self.allow_sigpipe = allow_sigpipe
        self.process: subprocess.Popen[bytes] | None = None

    def __enter__(self) -> IO[bytes]:
        """Start the command and return its stdout pipe."""
        self.process = subprocess.Popen(  # noqa: S602
            self.command,
            shell=True,
            executable="/bin/bash",
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        if self.process.stdout is None:
            msg = f"Failed to open pipe command stdout: {self.command}"
            raise RuntimeError(msg)
        return self.process.stdout

    def __exit__(self, exc_type, exc, tb) -> bool:  # noqa: ANN001
        """Close pipes, wait for the command, and raise on unexpected failure.

        The command's stderr is appended to the error message for diagnostics.
        """
        if self.process is None:
            return False
        if self.process.stdout is not None and not self.process.stdout.closed:
            self.process.stdout.close()
        stderr = b""
        if self.process.stderr is not None:
            stderr = self.process.stderr.read()
            self.process.stderr.close()
        return_code = self.process.wait()
        # Only raise when no exception is already propagating.
        if exc_type is None and return_code != 0 and not self._is_allowed_sigpipe_return_code(return_code):
            detail = stderr.decode("utf-8", errors="replace").strip()
            msg = f"Pipe command failed with exit code {return_code}: {self.command}"
            if detail:
                msg += f"\n{detail}"
            raise RuntimeError(msg)
        return False  # never suppress an in-flight exception

    def _is_allowed_sigpipe_return_code(self, return_code: int) -> bool:
        # Shells report a SIGPIPE death as 128+signum; Popen may report -signum.
        if not self.allow_sigpipe or not hasattr(signal, "SIGPIPE"):
            return False
        sigpipe_value = signal.Signals.SIGPIPE.value
        return return_code in {128 + sigpipe_value, -sigpipe_value}


def _expand_spec_string(spec: str) -> list[str]:
    """Recursively expand the first shard-range pattern found in ``spec``.

    Numbers are zero-padded to the wider of the two range endpoints.

    Raises:
        ValueError: if the range end is smaller than the start.
    """
    for pattern in (_OP_CL_PATTERN, _BRACE_RANGE_PATTERN):
        match = pattern.search(spec)
        if match is None:
            continue
        start_str, end_str = match.groups()
        start = int(start_str)
        end = int(end_str)
        if end < start:
            msg = f"Invalid shard range: start={start}, end={end}, spec={spec}"
            raise ValueError(msg)
        width = max(len(start_str), len(end_str))
        prefix = spec[: match.start()]
        suffix = spec[match.end() :]
        expanded = [f"{prefix}{value:0{width}d}{suffix}" for value in range(start, end + 1)]
        # Recurse so specs with multiple ranges expand fully.
        results: list[str] = []
        for item in expanded:
            results.extend(_expand_spec_string(item))
        return results
    return [spec]


def expand_sharded_paths(paths: str | list[str]) -> list[str]:
    """Expand shard-range patterns in a path or list of paths."""
    if isinstance(paths, list):
        results: list[str] = []
        for path in paths:
            results.extend(_expand_spec_string(path))
        return results
    return _expand_spec_string(paths)


def resolve_transport(path: str, transport: Literal["auto", "fsspec", "pipe"]) -> Literal["fsspec", "pipe"]:
    """Pick the concrete transport: 'auto' means pipe for ``pipe:`` URIs, else fsspec."""
    if transport == "auto":
        return "pipe" if path.startswith("pipe:") else "fsspec"
    return transport


@contextmanager
def open_binary_stream(
    path: str,
    *,
    storage_options: dict[str, Any] | None = None,
    transport: Literal["auto", "fsspec", "pipe"] = "auto",
    allow_sigpipe: bool = False,
) -> Iterator[IO[bytes]]:
    """Yield a readable binary stream for ``path`` via pipe or fsspec."""
    resolved_transport = resolve_transport(path, transport)
    if resolved_transport == "pipe":
        # Strip the "pipe:" scheme when present; the remainder is the shell command.
        command = path[len("pipe:") :].strip() if path.startswith("pipe:") else path
        with PipeStream(command, allow_sigpipe=allow_sigpipe) as stream:
            yield stream
    else:
        with fsspec.open(path, mode="rb", **(storage_options or {})) as stream:
            yield stream


@contextmanager
def open_text_stream(
    path: str,
    *,
    storage_options: dict[str, Any] | None = None,
    transport: Literal["auto", "fsspec", "pipe"] = "auto",
    encoding: str = "utf-8",
) -> Iterator[IO[str]]:
    """Yield a readable text stream for ``path`` via pipe or fsspec."""
    resolved_transport = resolve_transport(path, transport)
    if resolved_transport == "pipe":
        with open_binary_stream(path, storage_options=storage_options, transport=transport) as stream:
            text_stream = io.TextIOWrapper(stream, encoding=encoding)
            try:
                yield text_stream
            finally:
                # Detach (not close) so the underlying pipe is closed by PipeStream.
                text_stream.detach()
    else:
        with fsspec.open(path, mode="rt", encoding=encoding, **(storage_options or {})) as stream:
            yield stream


def read_binary(
    path: str,
    *,
    storage_options: dict[str, Any] | None = None,
    transport: Literal["auto", "fsspec", "pipe"] = "auto",
    allow_sigpipe: bool = False,
) -> bytes:
    """Read the full binary contents of ``path``."""
    with open_binary_stream(
        path,
        storage_options=storage_options,
        transport=transport,
        allow_sigpipe=allow_sigpipe,
    ) as stream:
        return stream.read()


def iter_tar_member_names(
    tar_path: str,
    *,
    storage_options: dict[str, Any] | None = None,
    transport: Literal["auto", "fsspec", "pipe"] = "auto",
) -> Iterator[str]:
    """Yield regular-file member names from a tar archive, reading it as a stream."""
    with (
        open_binary_stream(tar_path, storage_options=storage_options, transport=transport) as stream,
        tarfile.open(fileobj=stream, mode="r|*") as tar,  # "r|*": sequential read, any compression
    ):
        for member in tar:
            if member.isfile():
                yield member.name


def copy_path(
    source_path: str,
    destination_path: str,
    *,
    source_storage_options: dict[str, Any] | None = None,
    destination_storage_options: dict[str, Any] | None = None,
    source_transport: Literal["auto", "fsspec", "pipe"] = "auto",
) -> None:
    """Stream-copy ``source_path`` to ``destination_path`` (1 MiB chunks).

    Raises:
        ValueError: if the destination is a ``pipe:`` URI (write not supported).
    """
    if destination_path.startswith("pipe:"):
        msg = f"Writing to pipe destinations is not supported: {destination_path}"
        raise ValueError(msg)

    fs, resolved_destination = url_to_fs(destination_path, **(destination_storage_options or {}))
    parent_dir = posixpath.dirname(resolved_destination)
    if parent_dir:
        fs.makedirs(parent_dir, exist_ok=True)

    with (
        open_binary_stream(
            source_path,
            storage_options=source_storage_options,
            transport=source_transport,
            allow_sigpipe=True,
        ) as source_stream,
        fs.open(resolved_destination, "wb") as destination_stream,
    ):
        shutil.copyfileobj(source_stream, destination_stream, length=1024 * 1024)


def remove_path(
    path: str,
    *,
    storage_options: dict[str, Any] | None = None,
    recursive: bool = False,
    ignore_missing: bool = False,
) -> None:
    """Delete ``path`` via fsspec; optionally tolerate a missing target.

    Raises:
        ValueError: if the path is a ``pipe:`` URI.
    """
    if path.startswith("pipe:"):
        msg = f"Deleting pipe destinations is not supported: {path}"
        raise ValueError(msg)

    fs, resolved_path = url_to_fs(path, **(storage_options or {}))
    try:
        fs.rm(resolved_path, recursive=recursive)
    except FileNotFoundError:
        if not ignore_missing:
            raise


def basename_from_path(path: str) -> str:
    """Return the final path component, handling ``pipe:`` prefixes and trailing slashes."""
    stripped = path[len("pipe:") :].strip() if path.startswith("pipe:") else path
    stripped = stripped.rstrip("/")
    if not stripped:
        return ""
    return stripped.rsplit("/", maxsplit=1)[-1]


def build_remote_uri(*, protocol: str, bucket: str, key: str = "") -> str:
    """Build ``protocol://bucket[/key]`` from normalized components.

    Raises:
        ValueError: if the bucket is empty after stripping slashes.
    """
    normalized_bucket = bucket.strip("/")
    if not normalized_bucket:
        msg = "bucket is required to build a remote URI"
        raise ValueError(msg)
    normalized_key = key.strip("/")
    if normalized_key:
        return f"{protocol}://{normalized_bucket}/{normalized_key}"
    return f"{protocol}://{normalized_bucket}"
= AudioTask( + task_id="only", + dataset_name="ds", + data={"audio_filepath": "/x.wav", "text": "hi"}, + sample_key="sample-serialized", + ) + stage = AudioToDocumentStage() + result = stage.process_batch([task]) + + assert result[0].data.iloc[0]["sample_key"] == "sample-serialized" diff --git a/tests/stages/audio/io/test_tarred.py b/tests/stages/audio/io/test_tarred.py new file mode 100644 index 0000000000..3bbdd50302 --- /dev/null +++ b/tests/stages/audio/io/test_tarred.py @@ -0,0 +1,556 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import io +import json +import tarfile +import wave +from dataclasses import dataclass +from pathlib import Path + +import pytest + +from nemo_curator.stages.audio.io.materialize import CleanupTemporaryAudioStage +from nemo_curator.stages.audio.io.tarred import ( + MaterializeTarredAudioStage, + TarredAudioManifestPartitionStage, + TarredAudioManifestReader, + TarredAudioManifestReaderStage, + _PipeStream, +) +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.tasks import AudioTask, FileGroupTask, _EmptyTask + + +def _write_tar(path: Path, members: dict[str, bytes]) -> None: + with tarfile.open(path, "w") as tar: + for name, content in members.items(): + info = tarfile.TarInfo(name=name) + info.size = len(content) + tar.addfile(info, io.BytesIO(content)) + + +def _make_wav_bytes(*, sample_rate: int = 16000, duration_sec: float = 1.0) -> bytes: + frames = int(sample_rate * duration_sec) + data = (b"\x00\x00" * frames) + buffer = io.BytesIO() + with wave.open(buffer, "wb") as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) + wav_file.setframerate(sample_rate) + wav_file.writeframes(data) + return buffer.getvalue() + + +def _make_file_group_task(paths: list[str]) -> FileGroupTask: + return FileGroupTask(task_id="group", dataset_name="dataset", data=paths) + + +@dataclass +class _PathConsumerStage(ProcessingStage[AudioTask, AudioTask]): + audio_filepath_key: str = "audio_filepath" + seen_exists_key: str = "_path_exists" + name: str = "path_consumer" + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [self.audio_filepath_key] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [self.seen_exists_key] + + def process(self, task: AudioTask) -> AudioTask: + task.data[self.seen_exists_key] = Path(task.data[self.audio_filepath_key]).exists() + return task + + +class TestTarredAudioManifestReader: + def test_partition_stage_expands_op_cl_pattern(self, tmp_path: Path) -> None: + for shard_id in range(2): + 
(tmp_path / f"manifest_{shard_id}.json").write_text( + json.dumps({"audio_filepath": f"{shard_id}.wav", "text": f"text-{shard_id}"}) + "\n" + ) + + stage = TarredAudioManifestPartitionStage( + manifest_paths=str(tmp_path / "manifest__OP_0..1_CL_.json"), + files_per_partition=1, + ) + result = stage.process(_EmptyTask) + + assert len(result) == 2 + assert result[0].data == [str(tmp_path / "manifest_0.json")] + assert result[1].data == [str(tmp_path / "manifest_1.json")] + + def test_reader_maps_manifest_entries_to_tar_members(self, tmp_path: Path) -> None: + manifest = tmp_path / "manifest_0.json" + tar_path = tmp_path / "audio_0.tar" + manifest.write_text( + "\n".join( + [ + json.dumps({"audio_filepath": "a.wav", "text": "alpha"}), + json.dumps({"audio_filepath": "b.wav-sub1", "text": "beta", "offset": 0.0, "duration": 0.5}), + ] + ) + ) + _write_tar(tar_path, {"a.wav": b"a-bytes", "b.wav": _make_wav_bytes(duration_sec=1.0)}) + + stage = TarredAudioManifestReaderStage(tar_paths=str(tar_path)) + result = stage.process(_make_file_group_task([str(manifest)])) + + assert len(result) == 2 + assert result[0].data["_tar_path"] == str(tar_path) + assert result[0].data["_tar_member"] == "a.wav" + assert result[0].sample_key + assert result[1].data["_tar_member"] == "b.wav" + assert result[1].data["audio_filepath"] == "b.wav-sub1" + assert result[1].data["_audio_source_type"] == "tarred" + assert result[1].sample_key + assert result[0].sample_key != result[1].sample_key + + def test_reader_raises_when_manifest_entry_missing_in_tar(self, tmp_path: Path) -> None: + manifest = tmp_path / "manifest_0.json" + tar_path = tmp_path / "audio_0.tar" + manifest.write_text( + "\n".join( + [ + json.dumps({"audio_filepath": "a.wav"}), + json.dumps({"audio_filepath": "missing.wav"}), + ] + ) + ) + _write_tar(tar_path, {"a.wav": b"a-bytes"}) + + stage = TarredAudioManifestReaderStage(tar_paths=str(tar_path), skip_missing_entries=False) + + with pytest.raises(RuntimeError, match=r"Cannot 
locate tar member 'missing\.wav'"): + stage.process(_make_file_group_task([str(manifest)])) + + def test_reader_skips_missing_entries_when_enabled(self, tmp_path: Path) -> None: + manifest = tmp_path / "manifest_0.json" + tar_path = tmp_path / "audio_0.tar" + manifest.write_text( + "\n".join( + [ + json.dumps({"audio_filepath": "a.wav"}), + json.dumps({"audio_filepath": "missing.wav"}), + ] + ) + ) + _write_tar(tar_path, {"a.wav": b"a-bytes"}) + + stage = TarredAudioManifestReaderStage(tar_paths=str(tar_path), skip_missing_entries=True) + result = stage.process(_make_file_group_task([str(manifest)])) + + assert len(result) == 1 + assert result[0].data["audio_filepath"] == "a.wav" + + def test_composite_decomposes_and_reads_sharded_specs(self, tmp_path: Path) -> None: + manifest_0 = tmp_path / "manifest_0.json" + manifest_1 = tmp_path / "manifest_1.json" + tar_0 = tmp_path / "audio_0.tar" + tar_1 = tmp_path / "audio_1.tar" + + manifest_0.write_text(json.dumps({"audio_filepath": "a.wav", "text": "alpha"}) + "\n") + manifest_1.write_text(json.dumps({"audio_filepath": "b.wav", "text": "beta"}) + "\n") + _write_tar(tar_0, {"a.wav": b"a-bytes"}) + _write_tar(tar_1, {"b.wav": b"b-bytes"}) + + reader = TarredAudioManifestReader( + manifest_paths=str(tmp_path / "manifest__OP_0..1_CL_.json"), + tar_paths=str(tmp_path / "audio__OP_0..1_CL_.tar"), + files_per_partition=1, + ) + partition_stage, reader_stage = reader.decompose() + + file_tasks = partition_stage.process(_EmptyTask) + results = reader_stage.process(file_tasks[0]) + + assert len(file_tasks) == 2 + assert len(results) == 1 + assert results[0].data["_tar_path"] == str(tar_0) + + def test_composite_propagates_limit_to_partition_stage(self, tmp_path: Path) -> None: + for shard_id in range(2): + (tmp_path / f"manifest_{shard_id}.json").write_text( + json.dumps({"audio_filepath": f"{shard_id}.wav", "text": f"text-{shard_id}"}) + "\n" + ) + _write_tar(tmp_path / f"audio_{shard_id}.tar", {f"{shard_id}.wav": b"bytes"}) + + 
reader = TarredAudioManifestReader( + manifest_paths=str(tmp_path / "manifest__OP_0..1_CL_.json"), + tar_paths=str(tmp_path / "audio__OP_0..1_CL_.tar"), + files_per_partition=1, + limit=1, + ) + partition_stage, _reader_stage = reader.decompose() + file_tasks = partition_stage.process(_EmptyTask) + + assert len(file_tasks) == 1 + + +class TestTarredAudioMaterialization: + def test_pipe_stream_allows_sigpipe_when_opted_in(self) -> None: + class _FakeProcess: + def __init__(self, return_code: int): + self.stdout = io.BytesIO(b"") + self.stderr = io.BytesIO(b"") + self._return_code = return_code + + def wait(self) -> int: + return self._return_code + + pipe_stream = _PipeStream("dummy", allow_sigpipe=True) + pipe_stream.process = _FakeProcess(return_code=141) # type: ignore[assignment] + + assert pipe_stream.__exit__(None, None, None) is False + + def test_pipe_stream_raises_for_sigpipe_by_default(self) -> None: + class _FakeProcess: + def __init__(self, return_code: int): + self.stdout = io.BytesIO(b"") + self.stderr = io.BytesIO(b"") + self._return_code = return_code + + def wait(self) -> int: + return self._return_code + + pipe_stream = _PipeStream("dummy") + pipe_stream.process = _FakeProcess(return_code=141) # type: ignore[assignment] + + with pytest.raises(RuntimeError, match="Pipe command failed with exit code 141"): + pipe_stream.__exit__(None, None, None) + + def test_materialize_and_cleanup_roundtrip(self, tmp_path: Path) -> None: + tar_path = tmp_path / "audio_0.tar" + raw_audio = b"test-bytes" + _write_tar(tar_path, {"sample.wav": raw_audio}) + + task = AudioTask( + task_id="t1", + dataset_name="ds", + data={ + "audio_filepath": "sample.wav", + "_tar_path": str(tar_path), + "_tar_member": "sample.wav", + }, + ) + + materialize = MaterializeTarredAudioStage(temp_dir=str(tmp_path / "tmp")) + [materialized] = materialize.process_batch([task]) + + temp_path = Path(materialized.data["_temporary_audio_path"]) + assert temp_path.exists() + assert 
temp_path.read_bytes() == raw_audio + assert materialized.data["audio_filepath"] == temp_path.as_posix() + assert materialized.data["_manifest_audio_filepath"] == "sample.wav" + assert materialized.data["_materialization_mode"] == "member" + + cleanup = CleanupTemporaryAudioStage() + cleaned = cleanup.process(materialized) + + assert not temp_path.exists() + assert cleaned.data["audio_filepath"] == "sample.wav" + assert "_temporary_audio_path" not in cleaned.data + assert "_materialization_mode" not in cleaned.data + assert "_manifest_audio_filepath" not in cleaned.data + + def test_materialize_segment_for_offset_entries(self, tmp_path: Path) -> None: + tar_path = tmp_path / "audio_0.tar" + _write_tar(tar_path, {"sample.wav": _make_wav_bytes(duration_sec=1.0)}) + + task = AudioTask( + task_id="t1", + dataset_name="ds", + data={ + "audio_filepath": "sample.wav-sub1", + "offset": 0.25, + "duration": 0.5, + "_tar_path": str(tar_path), + "_tar_member": "sample.wav", + }, + ) + + materialize = MaterializeTarredAudioStage(temp_dir=str(tmp_path / "tmp")) + [materialized] = materialize.process_batch([task]) + + temp_path = Path(materialized.data["_temporary_audio_path"]) + assert temp_path.exists() + assert materialized.data["_materialization_mode"] == "segment" + + with wave.open(temp_path.as_posix(), "rb") as wav_file: + assert wav_file.getframerate() == 16000 + assert wav_file.getnframes() == 8000 # 0.5 sec at 16kHz + + def test_materialize_segment_for_duration_only_entries(self, tmp_path: Path) -> None: + tar_path = tmp_path / "audio_0.tar" + _write_tar(tar_path, {"sample.wav": _make_wav_bytes(duration_sec=1.0)}) + + task = AudioTask( + task_id="t1", + dataset_name="ds", + data={ + "audio_filepath": "sample.wav", + "offset": 0.0, + "duration": 0.5, + "_tar_path": str(tar_path), + "_tar_member": "sample.wav", + }, + ) + + materialize = MaterializeTarredAudioStage(temp_dir=str(tmp_path / "tmp")) + [materialized] = materialize.process_batch([task]) + + temp_path = 
Path(materialized.data["_temporary_audio_path"]) + assert temp_path.exists() + assert materialized.data["_materialization_mode"] == "segment" + + with wave.open(temp_path.as_posix(), "rb") as wav_file: + assert wav_file.getnframes() == 8000 # 0.5 sec at 16kHz + + @pytest.mark.parametrize("duration", [0.0, -0.25]) + def test_materialize_segment_raises_for_non_positive_duration(self, tmp_path: Path, duration: float) -> None: + tar_path = tmp_path / "audio_0.tar" + _write_tar(tar_path, {"sample.wav": _make_wav_bytes(duration_sec=1.0)}) + + task = AudioTask( + task_id="t1", + dataset_name="ds", + data={ + "audio_filepath": "sample.wav", + "offset": 0.0, + "duration": duration, + "_tar_path": str(tar_path), + "_tar_member": "sample.wav", + }, + ) + + materialize = MaterializeTarredAudioStage(temp_dir=str(tmp_path / "tmp")) + + with pytest.raises(RuntimeError, match="Duration must be greater than 0"): + materialize.process_batch([task]) + + def test_materialize_segment_raises_for_offset_past_audio_end(self, tmp_path: Path) -> None: + tar_path = tmp_path / "audio_0.tar" + _write_tar(tar_path, {"sample.wav": _make_wav_bytes(duration_sec=0.25)}) + + task = AudioTask( + task_id="t1", + dataset_name="ds", + data={ + "audio_filepath": "sample.wav-sub1", + "offset": 1.0, + "duration": 0.25, + "_tar_path": str(tar_path), + "_tar_member": "sample.wav", + }, + ) + + materialize = MaterializeTarredAudioStage(temp_dir=str(tmp_path / "tmp")) + + with pytest.raises(RuntimeError, match=r"Offset 1\.0s exceeds audio length"): + materialize.process_batch([task]) + + def test_materialize_to_durable_directory_keeps_file_after_cleanup(self, tmp_path: Path) -> None: + tar_path = tmp_path / "audio_0.tar" + raw_audio = b"durable-bytes" + _write_tar(tar_path, {"sample.wav": raw_audio}) + + task = AudioTask( + task_id="t1", + dataset_name="ds", + sample_key="sample-key-1", + data={ + "audio_filepath": "sample.wav", + "_tar_path": str(tar_path), + "_tar_member": "sample.wav", + }, + ) + + 
materialization_dir = tmp_path / "materialized" + materialize = MaterializeTarredAudioStage(materialization_dir=str(materialization_dir)) + [materialized] = materialize.process_batch([task]) + + durable_path = Path(materialized.data["audio_filepath"]) + assert durable_path.exists() + assert durable_path.read_bytes() == raw_audio + assert durable_path.is_relative_to(materialization_dir) + assert "_temporary_audio_path" not in materialized.data + + cleanup = CleanupTemporaryAudioStage() + cleaned = cleanup.process(materialized) + + assert durable_path.exists() + assert cleaned.data["audio_filepath"] == "sample.wav" + + def test_pipe_transport_reads_manifest_and_tar(self, tmp_path: Path) -> None: + manifest = tmp_path / "manifest_0.json" + tar_path = tmp_path / "audio_0.tar" + manifest.write_text(json.dumps({"audio_filepath": "sample.wav", "text": "alpha"}) + "\n") + _write_tar(tar_path, {"sample.wav": b"pipe-bytes"}) + + manifest_cmd = ( + f'pipe:python3 -c "from pathlib import Path; import sys; ' + f"sys.stdout.buffer.write(Path(r'{manifest}').read_bytes())\"" + ) + tar_cmd = ( + f'pipe:python3 -c "from pathlib import Path; import sys; ' + f"sys.stdout.buffer.write(Path(r'{tar_path}').read_bytes())\"" + ) + + reader_stage = TarredAudioManifestReaderStage(tar_paths=tar_cmd, transport="auto") + [audio_task] = reader_stage.process(_make_file_group_task([manifest_cmd])) + + materialize = MaterializeTarredAudioStage(temp_dir=str(tmp_path / "tmp"), transport="auto") + [materialized] = materialize.process_batch([audio_task]) + + temp_path = Path(materialized.data["_temporary_audio_path"]) + assert temp_path.exists() + assert temp_path.read_bytes() == b"pipe-bytes" + + def test_pipe_transport_skips_missing_entries_during_materialization(self, tmp_path: Path) -> None: + manifest = tmp_path / "manifest_0.json" + tar_path = tmp_path / "audio_0.tar" + manifest.write_text( + "\n".join( + [ + json.dumps({"audio_filepath": "sample.wav", "text": "alpha"}), + 
json.dumps({"audio_filepath": "missing.wav", "text": "beta"}), + ] + ) + ) + _write_tar(tar_path, {"sample.wav": b"pipe-bytes"}) + + tar_cmd = ( + f'pipe:python3 -c "from pathlib import Path; import sys; ' + f"sys.stdout.buffer.write(Path(r'{tar_path}').read_bytes())\"" + ) + + reader_stage = TarredAudioManifestReaderStage( + tar_paths=tar_cmd, + transport="auto", + skip_missing_entries=True, + ) + audio_tasks = reader_stage.process(_make_file_group_task([str(manifest)])) + + assert len(audio_tasks) == 2 + + materialize = MaterializeTarredAudioStage(temp_dir=str(tmp_path / "tmp"), transport="auto") + materialized = materialize.process_batch(audio_tasks) + + assert len(materialized) == 1 + assert materialized[0].data["text"] == "alpha" + + def test_pipe_transport_raises_for_missing_entries_when_skip_disabled(self, tmp_path: Path) -> None: + manifest = tmp_path / "manifest_0.json" + tar_path = tmp_path / "audio_0.tar" + manifest.write_text( + "\n".join( + [ + json.dumps({"audio_filepath": "sample.wav", "text": "alpha"}), + json.dumps({"audio_filepath": "missing.wav", "text": "beta"}), + ] + ) + ) + _write_tar(tar_path, {"sample.wav": b"pipe-bytes"}) + + tar_cmd = ( + f'pipe:python3 -c "from pathlib import Path; import sys; ' + f"sys.stdout.buffer.write(Path(r'{tar_path}').read_bytes())\"" + ) + + reader_stage = TarredAudioManifestReaderStage( + tar_paths=tar_cmd, + transport="auto", + skip_missing_entries=False, + ) + audio_tasks = reader_stage.process(_make_file_group_task([str(manifest)])) + + materialize = MaterializeTarredAudioStage(temp_dir=str(tmp_path / "tmp"), transport="auto") + + with pytest.raises(RuntimeError, match=r"Failed to materialize tar members \['missing\.wav'\]"): + materialize.process_batch(audio_tasks) + + def test_materialize_decodes_shared_member_once_for_segment_tasks( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + import nemo_curator.stages.audio.io.materialize as materialize_module + + tar_path = tmp_path / 
"audio_0.tar" + _write_tar(tar_path, {"sample.wav": _make_wav_bytes(duration_sec=1.0)}) + + tasks = [ + AudioTask( + task_id="t1", + dataset_name="ds", + data={ + "audio_filepath": "sample.wav-sub1", + "offset": 0.0, + "duration": 0.25, + "_tar_path": str(tar_path), + "_tar_member": "sample.wav", + }, + ), + AudioTask( + task_id="t2", + dataset_name="ds", + data={ + "audio_filepath": "sample.wav-sub2", + "offset": 0.25, + "duration": 0.25, + "_tar_path": str(tar_path), + "_tar_member": "sample.wav", + }, + ), + ] + + read_calls = 0 + original_read = materialize_module.soundfile.read + + def counting_read(*args: object, **kwargs: object) -> object: + nonlocal read_calls + read_calls += 1 + return original_read(*args, **kwargs) + + monkeypatch.setattr(materialize_module.soundfile, "read", counting_read) + + materialize = MaterializeTarredAudioStage(temp_dir=str(tmp_path / "tmp")) + materialize.process_batch(tasks) + + assert read_calls == 1 + + def test_manual_end_to_end_reader_materialize_consume_cleanup(self, tmp_path: Path) -> None: + manifest = tmp_path / "manifest_0.json" + tar_path = tmp_path / "audio_0.tar" + manifest.write_text(json.dumps({"audio_filepath": "sample.wav", "text": "alpha"}) + "\n") + _write_tar(tar_path, {"sample.wav": b"consumer-bytes"}) + + reader = TarredAudioManifestReader( + manifest_paths=str(manifest), + tar_paths=str(tar_path), + ) + partition_stage, reader_stage = reader.decompose() + [file_group] = partition_stage.process(_EmptyTask) + audio_tasks = reader_stage.process(file_group) + + materialize = MaterializeTarredAudioStage(temp_dir=str(tmp_path / "tmp")) + materialized = materialize.process_batch(audio_tasks) + + consumer = _PathConsumerStage() + consumed = consumer.process_batch(materialized) + assert consumed[0].data["_path_exists"] is True + + cleanup = CleanupTemporaryAudioStage() + cleaned = cleanup.process_batch(consumed) + + assert cleaned[0].data["audio_filepath"] == "sample.wav" + assert "_temporary_audio_path" not in 
cleaned[0].data diff --git a/tests/stages/file_io/__init__.py b/tests/stages/file_io/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/tests/stages/file_io/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/stages/file_io/test_files.py b/tests/stages/file_io/test_files.py new file mode 100644 index 0000000000..b509dd2a13 --- /dev/null +++ b/tests/stages/file_io/test_files.py @@ -0,0 +1,146 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +from pathlib import Path + +import fsspec +import pytest + +from nemo_curator.stages.file_io import ( + DeleteFilesStage, + MaterializeFilesStage, + UploadFilesStage, + UploadManifestStage, +) +from nemo_curator.tasks import AudioTask, FileGroupTask + + +def _write_remote_bytes(path: str, payload: bytes) -> None: + with fsspec.open(path, "wb") as fout: + fout.write(payload) + + +class TestFileStages: + def test_materialize_files_stage_writes_nested_output_field(self, tmp_path: Path) -> None: + remote_path = f"memory://audio/{tmp_path.name}/sample.wav" + _write_remote_bytes(remote_path, b"sample-bytes") + + task = AudioTask( + task_id="t1", + dataset_name="ds", + data={"artifacts": {"remote": {"audio_path": remote_path}}}, + ) + + stage = MaterializeFilesStage( + source_field_path="artifacts.remote.audio_path", + output_field_path="artifacts.local.materialized_path", + temp_dir=str(tmp_path / "tmp"), + ) + [materialized] = stage.process_batch([task]) + + local_path = Path(materialized.data["artifacts"]["local"]["materialized_path"]) + assert local_path.exists() + assert local_path.read_bytes() == b"sample-bytes" + assert materialized.data["artifacts"]["remote"]["audio_path"] == remote_path + + def test_materialize_files_stage_writes_shared_durable_output_once( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + remote_path = f"memory://audio/{tmp_path.name}/shared.wav" + _write_remote_bytes(remote_path, b"shared-bytes") + + tasks = [ + AudioTask(task_id="t1", dataset_name="ds", data={"artifacts": {"remote": {"audio_path": remote_path}}}), + AudioTask(task_id="t2", dataset_name="ds", data={"artifacts": {"remote": {"audio_path": remote_path}}}), + ] + + write_calls = 0 + original_write_bytes = Path.write_bytes + + def counting_write_bytes(path_obj: Path, data: bytes) -> int: + nonlocal write_calls + write_calls += 1 + return original_write_bytes(path_obj, data) + + monkeypatch.setattr(Path, "write_bytes", counting_write_bytes) + + 
stage = MaterializeFilesStage( + source_field_path="artifacts.remote.audio_path", + output_field_path="artifacts.local.materialized_path", + materialization_dir=str(tmp_path / "cache"), + ) + materialized = stage.process_batch(tasks) + + first_path = materialized[0].data["artifacts"]["local"]["materialized_path"] + second_path = materialized[1].data["artifacts"]["local"]["materialized_path"] + + assert first_path == second_path + assert write_calls == 1 + + def test_upload_files_stage_uses_nested_paths(self, tmp_path: Path) -> None: + local_path = tmp_path / "sample.wav" + local_path.write_bytes(b"upload-me") + + task = AudioTask( + task_id="t1", + dataset_name="ds", + data={ + "artifacts": {"local": {"path": str(local_path)}}, + "names": {"object_key": "nested/sample.wav"}, + }, + ) + + stage = UploadFilesStage( + source_field_path="artifacts.local.path", + output_field_path="artifacts.remote.uri", + protocol="memory", + bucket="uploaded-files", + key_field_path="names.object_key", + ) + uploaded = stage.process(task) + + uploaded_uri = uploaded.data["artifacts"]["remote"]["uri"] + assert uploaded_uri == "memory://uploaded-files/nested/sample.wav" + with fsspec.open(uploaded_uri, "rb") as fin: + assert fin.read() == b"upload-me" + + def test_delete_files_stage_removes_nested_field_and_object(self, tmp_path: Path) -> None: + remote_path = f"memory://delete/{tmp_path.name}/sample.wav" + _write_remote_bytes(remote_path, b"delete-me") + + task = AudioTask( + task_id="t1", + dataset_name="ds", + data={"artifacts": {"remote": {"uri": remote_path}}}, + ) + + stage = DeleteFilesStage(source_field_path="artifacts.remote.uri") + deleted = stage.process(task) + + fs, resolved = fsspec.core.url_to_fs(remote_path) + assert not fs.exists(resolved) + assert "uri" not in deleted.data["artifacts"]["remote"] + + def test_upload_manifest_stage_uploads_file_group_outputs(self, tmp_path: Path) -> None: + manifest = tmp_path / "output.jsonl" + 
manifest.write_text(json.dumps({"audio_filepath": "a.wav"}) + "\n") + + stage = UploadManifestStage(protocol="memory", bucket="manifest-bucket", key_prefix="jsonl") + result = stage.process(FileGroupTask(task_id="fg", dataset_name="ds", data=[str(manifest)])) + + assert result.data == ["memory://manifest-bucket/jsonl/output.jsonl"] + with fsspec.open(result.data[0], "rt", encoding="utf-8") as fin: + assert json.loads(fin.read().strip())["audio_filepath"] == "a.wav" + assert result._metadata["local_source_files"] == [str(manifest)] diff --git a/tests/stages/text/io/reader/test_jsonl.py b/tests/stages/text/io/reader/test_jsonl.py index d51eae9ce0..54f126e3e5 100644 --- a/tests/stages/text/io/reader/test_jsonl.py +++ b/tests/stages/text/io/reader/test_jsonl.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from pathlib import Path import pandas as pd @@ -20,8 +21,8 @@ from nemo_curator.stages.deduplication.id_generator import ( CURATOR_DEDUP_ID_STR, ) -from nemo_curator.stages.text.io.reader.jsonl import JsonlReader, JsonlReaderStage -from nemo_curator.tasks import FileGroupTask, _EmptyTask +from nemo_curator.stages.text.io.reader.jsonl import JsonlAudioReaderStage, JsonlReader, JsonlReaderStage +from nemo_curator.tasks import AudioTask, FileGroupTask, _EmptyTask @pytest.fixture @@ -122,6 +123,233 @@ def fake_read_json(_path: object, *_args: object, **kwargs: object) -> pd.DataFr assert len(df) == 2 +class TestJsonlAudioReader: + """Tests for the audio-task JSONL reader path.""" + + def test_audio_stage_reads_audio_tasks(self, tmp_path: Path) -> None: + manifest = tmp_path / "audio.jsonl" + entries = [ + {"audio_filepath": "a.wav", "text": "alpha", "segments": [{"start": 0.0, "end": 1.0}]}, + {"audio_filepath": "b.wav", "text": "beta", "segments": [{"start": 1.0, "end": 2.0}]}, + ] + manifest.write_text("\n".join(json.dumps(entry) for entry in entries)) + + stage = 
JsonlAudioReaderStage() + task = FileGroupTask( + task_id="audio_task", + dataset_name="audio_dataset", + data=[str(manifest)], + _metadata={"source": "unit-test"}, + ) + result = stage.process(task) + + assert len(result) == 2 + assert all(isinstance(item, AudioTask) for item in result) + assert [item.data["audio_filepath"] for item in result] == ["a.wav", "b.wav"] + assert [item.task_id for item in result] == ["audio_task_0", "audio_task_1"] + assert all(item.dataset_name == "audio_dataset" for item in result) + assert all(item._metadata == {"source": "unit-test"} for item in result) + assert all(item.sample_key for item in result) + assert result[0].sample_key != result[1].sample_key + assert result[0].data["segments"][0]["end"] == 1.0 + + def test_audio_stage_filters_fields_and_skips_blank_lines(self, tmp_path: Path) -> None: + manifest = tmp_path / "audio_fields.jsonl" + manifest.write_text( + json.dumps({"audio_filepath": "a.wav", "text": "alpha", "speaker_id": "spk1"}) + + "\n\n \n" + + json.dumps({"audio_filepath": "b.wav", "text": "beta", "speaker_id": "spk2"}) + + "\n" + ) + + stage = JsonlAudioReaderStage(fields=["audio_filepath", "text"]) + task = FileGroupTask(task_id="audio_task", dataset_name="audio_dataset", data=[str(manifest)], _metadata={}) + result = stage.process(task) + + assert len(result) == 2 + assert result[0].data["audio_filepath"] == "a.wav" + assert result[0].data["text"] == "alpha" + assert result[0].data["sample_key"] + assert result[1].data["audio_filepath"] == "b.wav" + assert result[1].data["text"] == "beta" + assert result[1].data["sample_key"] + + def test_audio_stage_preserves_explicit_sample_key_from_manifest(self, tmp_path: Path) -> None: + manifest = tmp_path / "audio_sample_keys.jsonl" + manifest.write_text( + json.dumps({"audio_filepath": "a.wav", "text": "alpha", "sample_key": "explicit-a"}) + "\n" + ) + + stage = JsonlAudioReaderStage(fields=["audio_filepath", "text"]) + task = FileGroupTask(task_id="audio_task", 
dataset_name="audio_dataset", data=[str(manifest)], _metadata={}) + [result] = stage.process(task) + + assert result.sample_key == "explicit-a" + + def test_audio_composite_decomposes_to_audio_stage(self, tmp_path: Path) -> None: + manifest = tmp_path / "audio_manifest.jsonl" + manifest.write_text(json.dumps({"audio_filepath": "a.wav", "text": "alpha"}) + "\n") + + reader = JsonlReader( + file_paths=str(tmp_path), + task_type="audio", + fields=["audio_filepath", "text"], + read_kwargs={"storage_options": {"anon": True}}, + ) + stages = reader.decompose() + + assert len(stages) == 2 + assert getattr(stages[0], "storage_options", None) == {"anon": True} + assert isinstance(stages[1], JsonlAudioReaderStage) + assert stages[1].fields == ["audio_filepath", "text"] + + def test_audio_composite_propagates_id_generation_flags(self, tmp_path: Path) -> None: + manifest = tmp_path / "audio_manifest.jsonl" + manifest.write_text(json.dumps({"audio_filepath": "a.wav", "text": "alpha"}) + "\n") + + reader = JsonlReader(file_paths=str(tmp_path), task_type="audio", _generate_ids=True) + stages = reader.decompose() + + assert len(stages) == 2 + assert isinstance(stages[1], JsonlAudioReaderStage) + assert stages[1]._generate_ids is True + assert stages[1]._assign_ids is False + + def test_audio_pipeline_outputs_audio_tasks(self, tmp_path: Path) -> None: + from nemo_curator.backends.xenna import XennaExecutor + from nemo_curator.pipeline import Pipeline + + input_dir = tmp_path / "audio_inputs" + input_dir.mkdir() + for file_idx in range(2): + manifest = input_dir / f"audio_{file_idx}.jsonl" + entries = [ + {"audio_filepath": f"{file_idx}_0.wav", "text": f"doc-{file_idx}-0"}, + {"audio_filepath": f"{file_idx}_1.wav", "text": f"doc-{file_idx}-1"}, + ] + manifest.write_text("\n".join(json.dumps(entry) for entry in entries)) + + pipeline = Pipeline(name="audio_reader_test") + pipeline.add_stage(JsonlReader(file_paths=str(input_dir), files_per_partition=1, task_type="audio")) + + results 
= pipeline.run(XennaExecutor(config={"execution_mode": "streaming"})) + + assert results is not None + assert len(results) == 4 + assert all(isinstance(task, AudioTask) for task in results) + assert sorted(task.data["audio_filepath"] for task in results) == [ + "0_0.wav", + "0_1.wav", + "1_0.wav", + "1_1.wav", + ] + + @pytest.mark.usefixtures("ray_client_with_id_generator") + def test_audio_stage_generates_and_assigns_stable_ids(self, tmp_path: Path) -> None: + manifest = tmp_path / "audio_ids.jsonl" + manifest.write_text( + "\n".join( + [ + json.dumps({"audio_filepath": "a.wav", "text": "alpha"}), + json.dumps({"audio_filepath": "b.wav", "text": "beta"}), + ] + ) + ) + task = FileGroupTask(task_id="audio_task", dataset_name="audio_dataset", data=[str(manifest)], _metadata={}) + + generation_stage = JsonlAudioReaderStage(_generate_ids=True) + generation_stage.setup() + + generated = generation_stage.process(task) + generated_ids = [audio_task.data[CURATOR_DEDUP_ID_STR] for audio_task in generated] + assert generated_ids == [0, 1] + + repeated = generation_stage.process(task) + repeated_ids = [audio_task.data[CURATOR_DEDUP_ID_STR] for audio_task in repeated] + assert repeated_ids == [0, 1] + + assign_stage = JsonlAudioReaderStage(_assign_ids=True) + assign_stage.setup() + assigned = assign_stage.process(task) + assigned_ids = [audio_task.data[CURATOR_DEDUP_ID_STR] for audio_task in assigned] + assert assigned_ids == [0, 1] + + def test_audio_stage_assign_ids_tolerates_extra_registered_ids_from_blank_lines( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + class _FakeRemoteMethod: + def __init__(self, return_value: tuple[int, int]) -> None: + self.return_value = return_value + + def remote(self, *_args: object, **_kwargs: object) -> tuple[int, int]: + return self.return_value + + class _FakeIdGenerator: + def __init__(self, return_value: tuple[int, int]) -> None: + self.get_batch_range = _FakeRemoteMethod(return_value) + + manifest = tmp_path / 
"audio_blank_lines.jsonl" + manifest.write_text( + json.dumps({"audio_filepath": "a.wav", "text": "alpha"}) + + "\n\n \n" + + json.dumps({"audio_filepath": "b.wav", "text": "beta"}) + + "\n" + ) + task = FileGroupTask(task_id="audio_task", dataset_name="audio_dataset", data=[str(manifest)], _metadata={}) + + stage = JsonlAudioReaderStage(_assign_ids=True) + stage.id_generator = _FakeIdGenerator((10, 12)) + monkeypatch.setattr("nemo_curator.stages.text.io.reader.jsonl.ray.get", lambda value: value) + + assigned = stage.process(task) + assigned_ids = [audio_task.data[CURATOR_DEDUP_ID_STR] for audio_task in assigned] + + assert assigned_ids == [10, 11] + + def test_audio_stage_assign_ids_raises_clear_error_when_registered_range_is_too_short( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + class _FakeRemoteMethod: + def __init__(self, return_value: tuple[int, int]) -> None: + self.return_value = return_value + + def remote(self, *_args: object, **_kwargs: object) -> tuple[int, int]: + return self.return_value + + class _FakeIdGenerator: + def __init__(self, return_value: tuple[int, int]) -> None: + self.get_batch_range = _FakeRemoteMethod(return_value) + + manifest = tmp_path / "audio_short_range.jsonl" + manifest.write_text( + "\n".join( + [ + json.dumps({"audio_filepath": "a.wav", "text": "alpha"}), + json.dumps({"audio_filepath": "b.wav", "text": "beta"}), + ] + ) + ) + task = FileGroupTask(task_id="audio_task", dataset_name="audio_dataset", data=[str(manifest)], _metadata={}) + + stage = JsonlAudioReaderStage(_assign_ids=True) + stage.id_generator = _FakeIdGenerator((20, 20)) + monkeypatch.setattr("nemo_curator.stages.text.io.reader.jsonl.ray.get", lambda value: value) + + with pytest.raises(RuntimeError, match="contains 1 IDs, but the audio JSONL reader produced 2 tasks"): + stage.process(task) + + def test_audio_stage_generate_ids_no_actor_error(self) -> None: + stage = JsonlAudioReaderStage(_generate_ids=True) + + with 
pytest.raises(RuntimeError, match="actor 'id_generator' does not exist"): + stage.setup() + + stage = JsonlAudioReaderStage(_assign_ids=True) + + with pytest.raises(RuntimeError, match="actor 'id_generator' does not exist"): + stage.setup() + + class TestJsonlReaderWithIdGenerator: """Test JSONL reader with ID generation.""" diff --git a/tests/tasks/test_audio_task.py b/tests/tasks/test_audio_task.py index 1438c239ea..f1e7621247 100644 --- a/tests/tasks/test_audio_task.py +++ b/tests/tasks/test_audio_task.py @@ -17,6 +17,7 @@ from pathlib import Path from nemo_curator.tasks import AudioTask +from nemo_curator.tasks.audio_task import build_audio_sample_key def test_audio_task_stores_dict() -> None: @@ -49,3 +50,28 @@ def test_audio_task_validation_missing_file(tmp_path: Path) -> None: def test_audio_task_validation_no_filepath_key() -> None: entry = AudioTask(data={"text": "hello"}) assert entry.validate() is True + + +def test_audio_task_propagates_explicit_sample_key_from_data() -> None: + entry = AudioTask(data={"audio_filepath": "/x.wav", "sample_key": "sample-123"}) + assert entry.sample_key == "sample-123" + + +def test_audio_task_persists_constructor_sample_key_back_to_data() -> None: + entry = AudioTask(data={"audio_filepath": "/x.wav"}, sample_key="sample-456") + assert entry.sample_key == "sample-456" + assert entry.data["sample_key"] == "sample-456" + + +def test_build_audio_sample_key_is_stable_for_same_identity() -> None: + entry = { + "audio_filepath": "/a.wav", + "offset": 0.25, + "duration": 1.5, + } + + first = build_audio_sample_key(entry, dataset_name="dataset") + second = build_audio_sample_key(dict(entry), dataset_name="dataset") + + assert first == second + assert first