From 507b8c47ea1fd2c7668b7d7c22acf7fbec67cd2b Mon Sep 17 00:00:00 2001 From: shbhawsar Date: Wed, 8 Apr 2026 14:39:14 +0000 Subject: [PATCH 01/11] Fix waveform tensor leak and redesign AudioDataFilterStage for all VAD/Speaker combos Fixes the OverflowError crash in JsonlWriter caused by torch.Tensor waveform leaking through the pipeline into serialization. Root cause: When VAD and SpeakerSeparation were disabled, no cleanup stage ran before AudioToDocumentStage, so MonoConversionStage's waveform tensor leaked into the DataFrame and crashed ujson. Changes: - Redesign AudioDataFilterStage.decompose() as a topology selector with 4 distinct pipeline builders (one per VAD/Speaker combo), each producing a self-consistent stage sequence ending with TimestampMapper. - Separate VAD (_make_vad) from quality filters (_append_quality_filters) for clarity: VAD is a segmentation stage, not a filter. - Fix TimestampMapper to handle all 4 combos: 1. segment_mappings (full pipeline with SegmentConcat) 2. start_ms/end_ms (VAD fan-out) 3. diar_segments (SpeakerSep with speaker timing) 4. duration fallback (filters only, whole file) - Replace confusing _STRIP_KEYS blacklist with clean two-layer output control: _NEVER_PASS_KEYS safety net (blocks tensors always) + passthrough_keys whitelist (user-configurable, sensible default). - Fix SpeakerSeparation to propagate diar_segments from diarization model so Combo 3 output has meaningful speaker timing data. - Standardize duration key to "duration" across all stages. - Add BandFilter warning on unexpected prediction values. - Update extract_segments.py to handle all 4 combos with auto-detection, metadata.csv generation, and generic score extraction. - Add --execution-mode flag to pipeline.py for batch/streaming control. - Update README.md and pipeline.yaml with topology documentation. --- .../audio_data_filter/audio_data_filter.py | 255 ++++---- nemo_curator/stages/audio/filtering/band.py | 2 + .../audio/postprocessing/timestamp_mapper.py | 161 +++-- .../audio/segmentation/speaker_separation.py | 11 +- .../speaker_separation_module/speaker_sep.py | 4 +- .../audio/segmentation/vad_segmentation.py | 5 +- .../postprocessing/test_timestamp_mapper.py | 485 +++++++-------- .../segmentation/test_speaker_separation.py | 14 +- .../segmentation/test_vad_segmentation.py | 6 +- tutorials/audio/readspeech/README.md | 163 +++-- .../audio/readspeech/extract_segments.py | 558 +++++++++++++----- tutorials/audio/readspeech/pipeline.py | 52 +- tutorials/audio/readspeech/pipeline.yaml | 11 +- 13 files changed, 1043 insertions(+), 684 deletions(-) diff --git a/nemo_curator/stages/audio/advanced_pipelines/audio_data_filter/audio_data_filter.py b/nemo_curator/stages/audio/advanced_pipelines/audio_data_filter/audio_data_filter.py index 742e4b8d20..0f17f9e085 100644 --- a/nemo_curator/stages/audio/advanced_pipelines/audio_data_filter/audio_data_filter.py +++ b/nemo_curator/stages/audio/advanced_pipelines/audio_data_filter/audio_data_filter.py @@ -63,6 +63,13 @@ class AudioDataFilterStage(CompositeStage): cross-file parallelism. Each stage owns its own default resource allocation. Use ``.with_()`` to override individual stage resources. 
+ Supports four pipeline topologies based on which features are enabled: + + - **Combo 1** (VAD=off, Speaker=off): MonoConversion → Filters → TimestampMapper + - **Combo 2** (VAD=on, Speaker=off): MonoConversion → VAD(fan-out) → Filters → TimestampMapper + - **Combo 3** (VAD=off, Speaker=on): MonoConversion → Filters → SpeakerSep → Filters → TimestampMapper + - **Combo 4** (VAD=on, Speaker=on): Full pipeline with SegmentConcat + TimestampMapper + Args: config_path: Path to a YAML config file. When *None* the built-in ``default_config.yaml`` is used. @@ -84,136 +91,160 @@ def __init__( self._cfg = _deep_merge(self._cfg, config) def decompose(self) -> list[ProcessingStage]: + """Build a self-consistent pipeline topology based on enabled features.""" cfg = self._cfg - stages: list[ProcessingStage] = [] - mc = cfg.get("mono_conversion", {}) + enable_vad = cfg.get("vad", {}).get("enable", True) + enable_speaker = cfg.get("speaker_separation", {}).get("enable", True) + + if enable_vad and enable_speaker: + stages = self._build_full_pipeline(cfg) + elif enable_vad: + stages = self._build_vad_only_pipeline(cfg) + elif enable_speaker: + stages = self._build_speaker_only_pipeline(cfg) + else: + stages = self._build_filters_only_pipeline(cfg) + + enabled = get_enabled_stages(cfg) + logger.info( + f"AudioDataFilterStage decomposed into {len(stages)} stages " + f"(enabled: {enabled}, speaker_sep: {enable_speaker})" + ) + return stages + + # ------------------------------------------------------------------ + # Topology builders (one per feature combination) + # ------------------------------------------------------------------ + + def _build_full_pipeline(self, cfg: dict) -> list[ProcessingStage]: + """Combo 4: VAD=on, Speaker=on. Identical to the original design.""" + stages: list[ProcessingStage] = [self._make_mono(cfg)] + + stages.append(self._make_vad(cfg, suffix="", nested=True)) + self._append_quality_filters(stages, cfg, suffix="") + + concat = cfg.get("concatenation", {}) stages.append( - MonoConversionStage( - output_sample_rate=mc.get("output_sample_rate", 48000), - strict_sample_rate=mc.get("strict_sample_rate", True), - name="MonoConversion", - resources=Resources(cpus=mc.get("cpus", 1.0)), + SegmentConcatenationStage( + silence_duration_sec=concat.get("silence_duration_sec", 0.5), + name="SegmentConcat", + resources=Resources(cpus=concat.get("cpus", 1.0)), ) ) - vad = cfg.get("vad", {}) - band = cfg.get("band_filter", {}) - utmos = cfg.get("utmos", {}) - sigmos = cfg.get("sigmos", {}) - speaker = cfg.get("speaker_separation", {}) - concat = cfg.get("concatenation", {}) - ts = cfg.get("timestamp_mapper", {}) + stages.append(self._make_speaker_sep(cfg)) - enable_vad = vad.get("enable", True) - enable_band = band.get("enable", True) - enable_utmos = utmos.get("enable", True) - enable_sigmos = sigmos.get("enable", True) - enable_speaker = speaker.get("enable", True) - - self._append_filter_stages( - stages, - vad, - band, - utmos, - sigmos, - enable_vad, - enable_band, - enable_utmos, - enable_sigmos, - suffix="", - ) + stages.append(self._make_vad(cfg, suffix="_Speaker", nested=False)) + self._append_quality_filters(stages, cfg, suffix="_Speaker") - if enable_speaker: - if enable_vad: - stages.append( - SegmentConcatenationStage( - silence_duration_sec=concat.get("silence_duration_sec", 0.5), - name="SegmentConcat", - resources=Resources(cpus=concat.get("cpus", 1.0)), - ) - ) + stages.append(self._make_timestamp_mapper(cfg)) + return stages - stages.append( - SpeakerSeparationStage( - 
exclude_overlaps=speaker.get("exclude_overlaps", True), - min_duration=speaker.get("min_duration", 0.8), - gap_threshold=speaker.get("gap_threshold", 0.1), - buffer_time=speaker.get("buffer_time", 0.5), - name="SpeakerSeparation", - resources=Resources( - cpus=speaker.get("cpus", 1.0), - gpus=speaker.get("gpus", 1.0), - ), - ) - ) + def _build_vad_only_pipeline(self, cfg: dict) -> list[ProcessingStage]: + """Combo 2: VAD=on, Speaker=off. VAD fans out, OutputNormalizer cleans up.""" + stages: list[ProcessingStage] = [self._make_mono(cfg)] - self._append_filter_stages( - stages, - vad, - band, - utmos, - sigmos, - enable_vad, - enable_band, - enable_utmos, - enable_sigmos, - suffix="_Speaker", - ) + stages.append(self._make_vad(cfg, suffix="", nested=False)) + self._append_quality_filters(stages, cfg, suffix="") - if enable_vad or enable_speaker: - stages.append( - TimestampMapperStage( - passthrough_keys=ts.get("passthrough_keys"), - name="TimestampMapper", - resources=Resources(cpus=ts.get("cpus", 1.0)), - ) - ) + stages.append(self._make_timestamp_mapper(cfg)) + return stages - enabled = get_enabled_stages(cfg) - logger.info( - f"AudioDataFilterStage decomposed into {len(stages)} stages " - f"(enabled: {enabled}, speaker_sep: {enable_speaker})" - ) + def _build_speaker_only_pipeline(self, cfg: dict) -> list[ProcessingStage]: + """Combo 3: VAD=off, Speaker=on. SpeakerSep fans out with diar_segments.""" + stages: list[ProcessingStage] = [self._make_mono(cfg)] + + self._append_quality_filters(stages, cfg, suffix="") + + stages.append(self._make_speaker_sep(cfg)) + + self._append_quality_filters(stages, cfg, suffix="_Speaker") + + stages.append(self._make_timestamp_mapper(cfg)) return stages + def _build_filters_only_pipeline(self, cfg: dict) -> list[ProcessingStage]: + """Combo 1: VAD=off, Speaker=off. 
Filters only, TimestampMapper cleans up.""" + stages: list[ProcessingStage] = [self._make_mono(cfg)] + + self._append_quality_filters(stages, cfg, suffix="") + + stages.append(self._make_timestamp_mapper(cfg)) + return stages + + # ------------------------------------------------------------------ + # Stage factories + # ------------------------------------------------------------------ + @staticmethod - def _append_filter_stages( # noqa: PLR0913 + def _make_mono(cfg: dict) -> MonoConversionStage: + mc = cfg.get("mono_conversion", {}) + return MonoConversionStage( + output_sample_rate=mc.get("output_sample_rate", 48000), + strict_sample_rate=mc.get("strict_sample_rate", True), + name="MonoConversion", + resources=Resources(cpus=mc.get("cpus", 1.0)), + ) + + @staticmethod + def _make_vad(cfg: dict, *, suffix: str, nested: bool) -> VADSegmentationStage: + vad = cfg.get("vad", {}) + return VADSegmentationStage( + min_duration_sec=vad.get("min_duration_sec", 2.0), + max_duration_sec=vad.get("max_duration_sec", 60.0), + threshold=vad.get("threshold", 0.5), + min_interval_ms=vad.get("min_interval_ms", 500), + speech_pad_ms=vad.get("speech_pad_ms", 300), + nested=nested, + name=f"VAD{suffix}", + resources=Resources( + cpus=vad.get("cpus", 1.0), + gpus=vad.get("gpus", 0.3), + ), + ) + + @staticmethod + def _make_speaker_sep(cfg: dict) -> SpeakerSeparationStage: + speaker = cfg.get("speaker_separation", {}) + return SpeakerSeparationStage( + exclude_overlaps=speaker.get("exclude_overlaps", True), + min_duration=speaker.get("min_duration", 0.8), + gap_threshold=speaker.get("gap_threshold", 0.1), + buffer_time=speaker.get("buffer_time", 0.5), + name="SpeakerSeparation", + resources=Resources( + cpus=speaker.get("cpus", 1.0), + gpus=speaker.get("gpus", 1.0), + ), + ) + + @staticmethod + def _make_timestamp_mapper(cfg: dict) -> TimestampMapperStage: + ts = cfg.get("timestamp_mapper", {}) + return TimestampMapperStage( + passthrough_keys=ts.get("passthrough_keys"), + name="TimestampMapper", + resources=Resources(cpus=ts.get("cpus", 1.0)), + ) + + # ------------------------------------------------------------------ + # Quality filter helpers + # ------------------------------------------------------------------ + + @staticmethod + def _append_quality_filters( stages: list[ProcessingStage], - vad: dict, - band: dict, - utmos: dict, - sigmos: dict, - enable_vad: bool, - enable_band: bool, - enable_utmos: bool, - enable_sigmos: bool, + cfg: dict, *, suffix: str, ) -> None: - """Append VAD + quality filter stages to *stages* list.""" - if enable_vad: - # Pre-speaker pass (suffix==""): nested=True so VAD stores segments - # inside the task for SegmentConcatenation to merge. - # Post-speaker pass (suffix=="_Speaker"): nested=False so VAD fans - # out into separate tasks for independent downstream processing. 
- stages.append( - VADSegmentationStage( - min_duration_sec=vad.get("min_duration_sec", 2.0), - max_duration_sec=vad.get("max_duration_sec", 60.0), - threshold=vad.get("threshold", 0.5), - min_interval_ms=vad.get("min_interval_ms", 500), - speech_pad_ms=vad.get("speech_pad_ms", 300), - nested=(suffix == ""), - name=f"VAD{suffix}", - resources=Resources( - cpus=vad.get("cpus", 1.0), - gpus=vad.get("gpus", 0.3), - ), - ) - ) + """Append quality filter stages (Band, UTMOS, SIGMOS) to *stages*.""" + band = cfg.get("band_filter", {}) + utmos = cfg.get("utmos", {}) + sigmos = cfg.get("sigmos", {}) - if enable_band: + if band.get("enable", True): stages.append( BandFilterStage( band_value=band.get("band_value", "full_band"), @@ -225,7 +256,7 @@ def _append_filter_stages( # noqa: PLR0913 ) ) - if enable_utmos: + if utmos.get("enable", True): stages.append( UTMOSFilterStage( mos_threshold=utmos.get("mos_threshold", 3.5), @@ -237,7 +268,7 @@ def _append_filter_stages( # noqa: PLR0913 ) ) - if enable_sigmos: + if sigmos.get("enable", True): stages.append( SIGMOSFilterStage( noise_threshold=sigmos.get("noise_threshold", 4.0), diff --git a/nemo_curator/stages/audio/filtering/band.py b/nemo_curator/stages/audio/filtering/band.py index 803f5cad94..be2b34c9bd 100644 --- a/nemo_curator/stages/audio/filtering/band.py +++ b/nemo_curator/stages/audio/filtering/band.py @@ -181,6 +181,8 @@ def _process_single(self, task: AudioTask) -> AudioTask | None: pred = self._predictor.predict_audio(waveform, sample_rate) if isinstance(pred, str) and not pred.startswith("Error") and pred in ("full_band", "narrow_band"): task.data["band_prediction"] = pred + else: + logger.warning(f"[{task.task_id}] BandFilter: unexpected prediction value: {pred!r}") except Exception as e: # noqa: BLE001 logger.exception(f"[BandFilter] Prediction error: {e}") return None diff --git a/nemo_curator/stages/audio/postprocessing/timestamp_mapper.py b/nemo_curator/stages/audio/postprocessing/timestamp_mapper.py index e6a1b7be47..622602513f 100644 --- a/nemo_curator/stages/audio/postprocessing/timestamp_mapper.py +++ b/nemo_curator/stages/audio/postprocessing/timestamp_mapper.py @@ -15,11 +15,24 @@ """ Timestamp mapper stage. -Resolves segment positions in the concatenated waveform back to -positions in the original audio file using segment mappings stored -in ``task._metadata["segment_mappings"]`` by SegmentConcatenationStage. - -Strips waveform from final output items (metadata-only output). +Normalizes task data at the pipeline output boundary. Handles four +sources of timing information (checked in priority order): + +1. ``segment_mappings`` in ``task._metadata`` -- remaps concat-space + positions back to original file positions. +2. ``start_ms`` / ``end_ms`` in ``task.data`` -- uses them directly + as original positions (from VAD fan-out). +3. ``diar_segments`` in ``task.data`` -- computes span from first + segment start to last segment end (from SpeakerSep). +4. ``duration`` fallback -- uses whole-file duration. + +Output control uses two layers: + +- **passthrough_keys** (whitelist): only keys in this list are copied + from the input to the output. Defaults to all built-in quality + filter and speaker metadata keys. Users can override via config. +- **_NEVER_PASS_KEYS** (safety net): non-serializable keys that are + always blocked, even if accidentally added to ``passthrough_keys``. 
""" from dataclasses import dataclass, field @@ -31,6 +44,32 @@ from nemo_curator.stages.resources import Resources from nemo_curator.tasks import AudioTask +_NEVER_PASS_KEYS = frozenset( + { + "waveform", + "audio", + "audio_data", + "audio_array", + "segments", + } +) + +_DEFAULT_PASSTHROUGH_KEYS: list[str] = [ + "speaker_id", + "num_speakers", + "speaking_duration", + "sample_rate", + "utmos_mos", + "sigmos_noise", + "sigmos_ovrl", + "sigmos_sig", + "sigmos_col", + "sigmos_disc", + "sigmos_loud", + "sigmos_reverb", + "band_prediction", +] + def _translate_to_original( mappings: list[dict[str, Any]], concat_start_ms: int, concat_end_ms: int @@ -65,42 +104,41 @@ def _translate_to_original( @dataclass class TimestampMapperStage(ProcessingStage[AudioTask, AudioTask]): """ - Map segment positions back to original file timestamps. + Normalize task data at the pipeline output boundary. - Reads ``task._metadata["segment_mappings"]`` (written by - SegmentConcatenationStage) and translates the task's - ``start_ms`` / ``end_ms`` to ``original_start_ms`` / - ``original_end_ms`` in the source file. + Constructs core output fields (``original_file``, timestamps, + duration) from available timing sources, then copies only the + keys listed in ``passthrough_keys`` from the input. - Strips ``waveform`` from output so the final output is - metadata-only (timestamps, quality scores, speaker info). + Args: + passthrough_keys: Keys to copy from input to output. + Defaults to all built-in quality filter and speaker + metadata keys. Override to include custom fields or + restrict the output schema. """ - passthrough_keys: list[str] | None = None + passthrough_keys: list[str] | None = field(default=None) name: str = "TimestampMapper" batch_size: int = 1 resources: Resources = field(default_factory=lambda: Resources(cpus=1.0)) - _STRIP_KEYS = frozenset( - { - "waveform", - "audio", - "audio_filepath", - "start_ms", - "end_ms", - "segment_num", - "original_file", - } - ) - def __post_init__(self): super().__init__() + if self.passthrough_keys is None: + self.passthrough_keys = list(_DEFAULT_PASSTHROUGH_KEYS) + blocked = set(self.passthrough_keys) & _NEVER_PASS_KEYS + if blocked: + logger.warning( + f"[TimestampMapper] passthrough_keys contains non-serializable " + f"keys that will be blocked: {sorted(blocked)}. " + f"These keys are never included in output." 
+ ) def inputs(self) -> tuple[list[str], list[str]]: return [], [] def outputs(self) -> tuple[list[str], list[str]]: - return [], ["original_file", "original_start_ms", "original_end_ms", "duration_ms", "duration_sec"] + return [], ["original_file", "original_start_ms", "original_end_ms", "duration_ms", "duration"] def process(self, task: AudioTask) -> AudioTask | list[AudioTask]: mappings = (task._metadata or {}).get("segment_mappings") @@ -140,14 +178,11 @@ def process(self, task: AudioTask) -> AudioTask | list[AudioTask]: return task def _copy_passthrough(self, item: dict[str, Any], result: dict[str, Any]) -> None: - if self.passthrough_keys is not None: - for key in self.passthrough_keys: - if key in item and item[key] is not None and key not in result: - result[key] = item[key] - else: - for key, val in item.items(): - if key not in self._STRIP_KEYS and key not in result and val is not None: - result[key] = val + for key in self.passthrough_keys: + if key in _NEVER_PASS_KEYS: + continue + if key in item and item[key] is not None and key not in result: + result[key] = item[key] def _build_output_item(self, item: dict[str, Any], orig: dict[str, Any]) -> dict[str, Any]: result: dict[str, Any] = { @@ -155,31 +190,53 @@ def _build_output_item(self, item: dict[str, Any], orig: dict[str, Any]) -> dict "original_start_ms": orig["original_start_ms"], "original_end_ms": orig["original_end_ms"], "duration_ms": orig["duration_ms"], - "duration_sec": orig["duration_ms"] / 1000.0, + "duration": orig["duration_ms"] / 1000.0, } self._copy_passthrough(item, result) return result def _build_output_item_no_mapping(self, item: dict[str, Any]) -> dict[str, Any]: - start_ms = item.get("start_ms", 0) - end_ms = item.get("end_ms", 0) - duration_ms = end_ms - start_ms - if duration_ms <= 0: - dur = item.get("duration") or item.get("duration_sec") - if dur is not None and float(dur) > 0: - duration_ms = int(float(dur) * 1000) - end_ms = start_ms + duration_ms - elif "waveform" in item and "sample_rate" in item: - wf = item["waveform"] - n = wf.shape[-1] if hasattr(wf, "shape") else len(wf) - duration_ms = int(n / item["sample_rate"] * 1000) - end_ms = start_ms + duration_ms result: dict[str, Any] = { "original_file": item.get("original_file", item.get("audio_filepath", "unknown")), - "original_start_ms": start_ms, - "original_end_ms": end_ms, - "duration_ms": duration_ms, - "duration_sec": duration_ms / 1000.0, } + + start_ms = item.get("start_ms") + end_ms = item.get("end_ms") + + if start_ms is not None and end_ms is not None and end_ms > start_ms: + result["original_start_ms"] = int(start_ms) + result["original_end_ms"] = int(end_ms) + result["duration_ms"] = int(end_ms - start_ms) + result["duration"] = (end_ms - start_ms) / 1000.0 + self._copy_passthrough(item, result) + return result + + diar_segments = item.get("diar_segments") + if diar_segments and len(diar_segments) > 0: + first_start = diar_segments[0][0] + last_end = diar_segments[-1][1] + result["original_start_ms"] = int(first_start * 1000) + result["original_end_ms"] = int(last_end * 1000) + result["duration_ms"] = int((last_end - first_start) * 1000) + result["duration"] = last_end - first_start + speaking = sum(end - start for start, end in diar_segments) + result["speaking_duration"] = round(speaking, 3) + result["diar_segments"] = [[round(s, 3), round(e, 3)] for s, e in diar_segments] + self._copy_passthrough(item, result) + return result + + dur = item.get("duration") + if dur is not None and float(dur) > 0: + duration_ms = int(float(dur) * 
1000) + result["original_start_ms"] = 0 + result["original_end_ms"] = duration_ms + result["duration_ms"] = duration_ms + result["duration"] = float(dur) + else: + result["original_start_ms"] = 0 + result["original_end_ms"] = 0 + result["duration_ms"] = 0 + result["duration"] = 0.0 + self._copy_passthrough(item, result) return result diff --git a/nemo_curator/stages/audio/segmentation/speaker_separation.py b/nemo_curator/stages/audio/segmentation/speaker_separation.py index 31641a2e4b..858bce831b 100755 --- a/nemo_curator/stages/audio/segmentation/speaker_separation.py +++ b/nemo_curator/stages/audio/segmentation/speaker_separation.py @@ -100,7 +100,7 @@ def inputs(self) -> tuple[list[str], list[str]]: return [], [] def outputs(self) -> tuple[list[str], list[str]]: - return [], ["waveform", "sample_rate", "speaker_id", "num_speakers", "duration_sec"] + return [], ["waveform", "sample_rate", "speaker_id", "num_speakers", "duration"] def ray_stage_spec(self) -> dict[str, Any]: return {RayStageSpecKeys.IS_FANOUT_STAGE: True} @@ -158,6 +158,8 @@ def _initialize_separator(self) -> None: logger.error(f"Failed to load speaker separator: {e}") raise + _INHERITED_DROP_KEYS = frozenset({"audio", "waveform", "duration", "num_samples"}) + def _build_speaker_tasks( self, speaker_audio_data: dict, @@ -167,18 +169,19 @@ def _build_speaker_tasks( """Build AudioTask list from speaker audio data.""" results: list[AudioTask] = [] num_speakers = len(speaker_audio_data) - for speaker_id, (speaker_audio_pydub, duration) in speaker_audio_data.items(): + for speaker_id, (speaker_audio_pydub, duration, diar_segments) in speaker_audio_data.items(): if duration < self.min_duration: logger.debug(f"Skipping {speaker_id}: duration {duration:.2f}s < {self.min_duration}s") continue spk_waveform, spk_sr = _pydub_to_waveform_sr(speaker_audio_pydub) speaker_data = { - **{k: v for k, v in item.items() if k not in ("audio", "waveform")}, + **{k: v for k, v in item.items() if k not in self._INHERITED_DROP_KEYS}, "waveform": spk_waveform, "sample_rate": spk_sr, "speaker_id": speaker_id, "num_speakers": num_speakers, - "duration_sec": duration, + "duration": duration, + "diar_segments": diar_segments, } spk_task = AudioTask( data=speaker_data, diff --git a/nemo_curator/stages/audio/segmentation/speaker_separation_module/speaker_sep.py b/nemo_curator/stages/audio/segmentation/speaker_separation_module/speaker_sep.py index 52cf44abaa..715d6452f8 100755 --- a/nemo_curator/stages/audio/segmentation/speaker_separation_module/speaker_sep.py +++ b/nemo_curator/stages/audio/segmentation/speaker_separation_module/speaker_sep.py @@ -456,7 +456,7 @@ def get_speaker_audio_data( # noqa: PLR0913, C901, PLR0912 exclude_overlaps: bool | None = None, min_duration: float | None = None, buffer_time: float | None = None, - ) -> dict[str, tuple[AudioSegment, float]]: + ) -> dict[str, tuple[AudioSegment, float, list[tuple[float, float]]]]: """ Process an audio file or waveform and return AudioSegment objects for each speaker. 
""" @@ -521,7 +521,7 @@ def get_speaker_audio_data( # noqa: PLR0913, C901, PLR0912 if silent_audio.rms < 1: continue - speaker_audio[speaker] = (silent_audio, total_duration) + speaker_audio[speaker] = (silent_audio, total_duration, segments) # Free the original audio to release memory before returning del original_audio diff --git a/nemo_curator/stages/audio/segmentation/vad_segmentation.py b/nemo_curator/stages/audio/segmentation/vad_segmentation.py index bdcc6c790c..1dda9b1159 100755 --- a/nemo_curator/stages/audio/segmentation/vad_segmentation.py +++ b/nemo_curator/stages/audio/segmentation/vad_segmentation.py @@ -104,7 +104,7 @@ def inputs(self) -> tuple[list[str], list[str]]: return [], [] def outputs(self) -> tuple[list[str], list[str]]: - return [], ["waveform", "sample_rate", "start_ms", "end_ms", "segment_num", "duration_sec"] + return [], ["waveform", "sample_rate", "start_ms", "end_ms", "segment_num", "duration"] def ray_stage_spec(self) -> dict[str, Any]: if self.nested: @@ -183,7 +183,6 @@ def _build_segment_item( "start_ms", "end_ms", "segment_num", - "duration_sec", "duration", "num_samples", ) @@ -195,7 +194,7 @@ def _build_segment_item( "start_ms": start_ms, "end_ms": end_ms, "segment_num": segment_num, - "duration_sec": (end_ms - start_ms) / 1000.0, + "duration": (end_ms - start_ms) / 1000.0, "original_file": item.get("original_file", item.get("audio_filepath", "unknown")), } ) diff --git a/tests/stages/audio/postprocessing/test_timestamp_mapper.py b/tests/stages/audio/postprocessing/test_timestamp_mapper.py index ed011a3e98..47e6af96fd 100644 --- a/tests/stages/audio/postprocessing/test_timestamp_mapper.py +++ b/tests/stages/audio/postprocessing/test_timestamp_mapper.py @@ -12,293 +12,218 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Unit tests for TimestampMapperStage.""" - import torch from nemo_curator.stages.audio.postprocessing.timestamp_mapper import ( + _NEVER_PASS_KEYS, TimestampMapperStage, - _translate_to_original, ) from nemo_curator.tasks import AudioTask -SAMPLE_MAPPINGS = [ - { - "original_file": "/data/audio.wav", - "original_start_ms": 2000, - "original_end_ms": 5000, - "concat_start_ms": 0, - "concat_end_ms": 3000, - "segment_index": 0, - }, - { - "original_file": "/data/audio.wav", - "original_start_ms": 12000, - "original_end_ms": 17000, - "concat_start_ms": 3500, - "concat_end_ms": 8500, - "segment_index": 1, - }, - { - "original_file": "/data/audio.wav", - "original_start_ms": 45000, - "original_end_ms": 49000, - "concat_start_ms": 9000, - "concat_end_ms": 13000, - "segment_index": 2, - }, -] - -def _make_task(item: dict, mappings: list | None = None, task_id: str = "test") -> AudioTask: - metadata = {} - if mappings is not None: - metadata["segment_mappings"] = mappings - task = AudioTask( - data=item, - task_id=task_id, - dataset_name="test", +def _make_task(data: dict, task_id: str = "test", metadata: dict | None = None) -> AudioTask: + t = AudioTask(data=data, task_id=task_id, dataset_name="test_ds") + if metadata: + t._metadata = metadata + return t + + +def test_combo4_with_segment_mappings() -> None: + """Full pipeline: remaps concat-space timestamps to original file positions.""" + mappings = [ + { + "concat_start_ms": 0, + "concat_end_ms": 2000, + "original_file": "test.wav", + "original_start_ms": 5000, + "original_end_ms": 7000, + }, + ] + task = _make_task( + {"waveform": torch.randn(1, 48000), "sample_rate": 48000, "start_ms": 100, "end_ms": 1500, "utmos_mos": 4.2}, + metadata={"segment_mappings": mappings}, ) - task._metadata = metadata - return task - - -class TestTranslateToOriginal: - def test_segment_within_single_mapping(self): - results = _translate_to_original(SAMPLE_MAPPINGS, 500, 2500) - assert len(results) == 1 - assert results[0]["original_file"] == "/data/audio.wav" - assert results[0]["original_start_ms"] == 2500 - assert results[0]["original_end_ms"] == 4500 - assert results[0]["duration_ms"] == 2000 - - def test_segment_spans_two_mappings(self): - results = _translate_to_original(SAMPLE_MAPPINGS, 2000, 5000) - assert len(results) == 2 - assert results[0]["original_start_ms"] == 4000 - assert results[0]["original_end_ms"] == 5000 - assert results[1]["original_start_ms"] == 12000 - assert results[1]["original_end_ms"] == 13500 - - def test_segment_in_silence_gap(self): - results = _translate_to_original(SAMPLE_MAPPINGS, 3000, 3500) - assert len(results) == 0 - - def test_segment_no_overlap(self): - results = _translate_to_original(SAMPLE_MAPPINGS, 14000, 15000) - assert len(results) == 0 - - def test_segment_covers_entire_mapping(self): - results = _translate_to_original(SAMPLE_MAPPINGS, 0, 3000) - assert len(results) == 1 - assert results[0]["original_start_ms"] == 2000 - assert results[0]["original_end_ms"] == 5000 - assert results[0]["duration_ms"] == 3000 - - def test_malformed_mapping_skipped(self): - bad_mappings = [{"concat_start_ms": 0, "concat_end_ms": 1000}] - results = _translate_to_original(bad_mappings, 0, 500) - assert len(results) == 0 - - def test_empty_mappings(self): - results = _translate_to_original([], 0, 1000) - assert len(results) == 0 - - -class TestTimestampMapperWithMappings: - def test_single_segment_maps_correctly(self): - stage = TimestampMapperStage() - task = _make_task( - {"start_ms": 500, "end_ms": 2500, "speaker_id": "speaker_0", 
"utmos_mos": 4.2}, - mappings=SAMPLE_MAPPINGS, - ) - - result = stage.process(task) - - assert isinstance(result, AudioTask) - assert result.data["original_file"] == "/data/audio.wav" - assert result.data["original_start_ms"] == 2500 - assert result.data["original_end_ms"] == 4500 - assert result.data["duration_ms"] == 2000 - assert result.data["duration_sec"] == 2.0 - assert result.data["speaker_id"] == "speaker_0" - assert result.data["utmos_mos"] == 4.2 - - def test_cross_boundary_segment_rejected(self): - stage = TimestampMapperStage() - task = _make_task({"start_ms": 2000, "end_ms": 5000}, mappings=SAMPLE_MAPPINGS) - - result = stage.process(task) - - assert result == [] - - def test_segment_in_silence_gap_produces_no_output(self): - stage = TimestampMapperStage() - task = _make_task({"start_ms": 3000, "end_ms": 3500}, mappings=SAMPLE_MAPPINGS) - - result = stage.process(task) - - assert result == [] - - def test_invalid_range_skipped(self): - stage = TimestampMapperStage() - task = _make_task({"start_ms": 5000, "end_ms": 2000}, mappings=SAMPLE_MAPPINGS) - - result = stage.process(task) - - assert result == [] - - def test_waveform_stripped_from_output(self): - stage = TimestampMapperStage() - task = _make_task( - { - "start_ms": 0, - "end_ms": 3000, - "waveform": torch.randn(1, 48000), - "audio": b"fake", - "audio_filepath": "/data/audio.wav", - "segment_num": 0, - "speaker_id": "speaker_0", - }, - mappings=SAMPLE_MAPPINGS, - ) - - result = stage.process(task) - - assert isinstance(result, AudioTask) - assert "waveform" not in result.data - assert "audio" not in result.data - assert "audio_filepath" not in result.data - assert "segment_num" not in result.data - assert "speaker_id" in result.data - - -class TestTimestampMapperNoMappings: - def test_no_mapping_uses_start_end_directly(self): - stage = TimestampMapperStage() - task = _make_task( - {"start_ms": 1000, "end_ms": 4000, "original_file": "/data/audio.wav", "speaker_id": "speaker_0"}, - mappings=None, - ) - - result = stage.process(task) - - assert isinstance(result, AudioTask) - assert result.data["original_file"] == "/data/audio.wav" - assert result.data["original_start_ms"] == 1000 - assert result.data["original_end_ms"] == 4000 - assert result.data["duration_ms"] == 3000 - assert result.data["duration_sec"] == 3.0 - - def test_no_mapping_falls_back_to_duration_sec(self): - stage = TimestampMapperStage() - task = _make_task({"original_file": "/data/audio.wav", "duration_sec": 5.0}, mappings=None) - - result = stage.process(task) - - assert isinstance(result, AudioTask) - assert result.data["duration_ms"] == 5000 - - def test_no_mapping_falls_back_to_waveform_length(self): - stage = TimestampMapperStage() - task = _make_task( - {"original_file": "/data/audio.wav", "waveform": torch.randn(1, 96000), "sample_rate": 48000}, - mappings=None, - ) - - result = stage.process(task) - - assert isinstance(result, AudioTask) - assert result.data["duration_ms"] == 2000 - assert "waveform" not in result.data - - def test_no_mapping_uses_audio_filepath_as_fallback(self): - stage = TimestampMapperStage() - task = _make_task({"audio_filepath": "/data/fallback.wav", "start_ms": 0, "end_ms": 1000}, mappings=None) - - result = stage.process(task) - - assert result.data["original_file"] == "/data/fallback.wav" - - -class TestPassthroughKeys: - def test_default_passes_all_non_stripped_keys(self): - stage = TimestampMapperStage() - task = _make_task( - { - "start_ms": 0, - "end_ms": 3000, - "speaker_id": "speaker_0", - "utmos_mos": 4.2, - 
"band_prediction": "full_band", - "sample_rate": 48000, - "is_mono": True, - }, - mappings=SAMPLE_MAPPINGS, - ) - - result = stage.process(task) - - assert result.data["speaker_id"] == "speaker_0" - assert result.data["utmos_mos"] == 4.2 - assert result.data["band_prediction"] == "full_band" - assert result.data["sample_rate"] == 48000 - assert result.data["is_mono"] is True - - def test_explicit_passthrough_keys_filters_output(self): - stage = TimestampMapperStage(passthrough_keys=["speaker_id", "utmos_mos"]) - task = _make_task( - { - "start_ms": 0, - "end_ms": 3000, - "speaker_id": "speaker_0", - "utmos_mos": 4.2, - "band_prediction": "full_band", - "sample_rate": 48000, - "is_mono": True, - }, - mappings=SAMPLE_MAPPINGS, - ) - - result = stage.process(task) - - assert result.data["speaker_id"] == "speaker_0" - assert result.data["utmos_mos"] == 4.2 - assert "band_prediction" not in result.data - assert "sample_rate" not in result.data - assert "is_mono" not in result.data - assert "original_file" in result.data - assert "duration_ms" in result.data - - def test_passthrough_keys_missing_key_ignored(self): - stage = TimestampMapperStage(passthrough_keys=["speaker_id", "nonexistent_key"]) - task = _make_task({"start_ms": 0, "end_ms": 3000, "speaker_id": "speaker_0"}, mappings=SAMPLE_MAPPINGS) - - result = stage.process(task) - - assert result.data["speaker_id"] == "speaker_0" - assert "nonexistent_key" not in result.data - - -class TestTimestampMapperParams: - def test_default_passthrough_keys(self): - stage = TimestampMapperStage() - assert stage.passthrough_keys is None - - def test_custom_passthrough_keys(self): - stage = TimestampMapperStage(passthrough_keys=["a", "b"]) - assert stage.passthrough_keys == ["a", "b"] - - -class TestEdgeCases: - def test_none_values_not_passed_through(self): - stage = TimestampMapperStage() - task = _make_task( - {"start_ms": 0, "end_ms": 3000, "speaker_id": None, "utmos_mos": 4.0}, - mappings=SAMPLE_MAPPINGS, - ) - - result = stage.process(task) + stage = TimestampMapperStage() + result = stage.process(task) + + assert result.data["original_file"] == "test.wav" + assert result.data["original_start_ms"] == 5100 + assert result.data["original_end_ms"] == 6500 + assert result.data["duration_ms"] == 1400 + assert result.data["utmos_mos"] == 4.2 + assert result.data["sample_rate"] == 48000 + assert "waveform" not in result.data + assert "start_ms" not in result.data + + +def test_combo2_vad_fanout_start_end() -> None: + """VAD fan-out: uses start_ms/end_ms directly.""" + task = _make_task( + { + "waveform": torch.randn(1, 48000), + "sample_rate": 48000, + "start_ms": 5200, + "end_ms": 15400, + "segment_num": 0, + "duration": 10.2, + "original_file": "/a.wav", + "utmos_mos": 4.2, + } + ) + stage = TimestampMapperStage() + result = stage.process(task) + + assert result.data["original_file"] == "/a.wav" + assert result.data["original_start_ms"] == 5200 + assert result.data["original_end_ms"] == 15400 + assert result.data["duration_ms"] == 10200 + assert abs(result.data["duration"] - 10.2) < 0.01 + assert result.data["utmos_mos"] == 4.2 + assert "waveform" not in result.data + assert "start_ms" not in result.data + assert "segment_num" not in result.data + + +def test_combo3_diar_segments() -> None: + """Speaker-only: computes span from diar_segments.""" + task = _make_task( + { + "waveform": torch.randn(1, 48000), + "sample_rate": 48000, + "speaker_id": "speaker_0", + "num_speakers": 3, + "duration": 42.6, + "diar_segments": [(5.2, 15.4), (30.1, 42.8), (100.0, 
120.5)], + "audio_filepath": "/a.wav", + "sigmos_noise": 4.5, + } + ) + stage = TimestampMapperStage() + result = stage.process(task) + + assert result.data["original_file"] == "/a.wav" + assert result.data["original_start_ms"] == 5200 + assert result.data["original_end_ms"] == 120500 + assert result.data["duration_ms"] == 115300 + assert abs(result.data["speaking_duration"] - 43.4) < 0.01 + assert len(result.data["diar_segments"]) == 3 + assert result.data["speaker_id"] == "speaker_0" + assert result.data["num_speakers"] == 3 + assert result.data["sigmos_noise"] == 4.5 + assert "waveform" not in result.data + + +def test_combo1_duration_fallback() -> None: + """Filters-only: uses duration from MonoConversion.""" + task = _make_task( + { + "audio_filepath": "/a.wav", + "waveform": torch.randn(1, 48000), + "sample_rate": 48000, + "duration": 10.5, + "is_mono": True, + "num_samples": 504000, + "sigmos_ovrl": 3.5, + } + ) + stage = TimestampMapperStage() + result = stage.process(task) + + assert result.data["original_file"] == "/a.wav" + assert result.data["original_start_ms"] == 0 + assert result.data["original_end_ms"] == 10500 + assert result.data["duration"] == 10.5 + assert result.data["sigmos_ovrl"] == 3.5 + assert result.data["sample_rate"] == 48000 + assert "waveform" not in result.data + assert "is_mono" not in result.data + assert "num_samples" not in result.data + + +def test_never_pass_keys_blocked() -> None: + """Non-serializable keys are blocked even if in passthrough_keys.""" + task = _make_task( + { + "audio_filepath": "/a.wav", + "waveform": torch.randn(1, 48000), + "segments": [{"waveform": torch.randn(1, 100)}], + "duration": 1.0, + "sigmos_ovrl": 3.0, + } + ) + stage = TimestampMapperStage(passthrough_keys=["waveform", "segments", "sigmos_ovrl"]) + result = stage.process(task) + + for key in _NEVER_PASS_KEYS: + assert key not in result.data, f"{key!r} must never pass through" + assert result.data["sigmos_ovrl"] == 3.0 + + +def test_default_passthrough_covers_all_filters() -> None: + """Default passthrough_keys includes all built-in filter scores.""" + task = _make_task( + { + "audio_filepath": "/a.wav", + "duration": 1.0, + "utmos_mos": 4.2, + "sigmos_noise": 4.0, + "sigmos_ovrl": 3.5, + "sigmos_sig": 3.8, + "sigmos_col": 4.0, + "sigmos_disc": 4.2, + "sigmos_loud": 3.7, + "sigmos_reverb": 4.9, + "band_prediction": "full_band", + "sample_rate": 48000, + } + ) + stage = TimestampMapperStage() + result = stage.process(task) + + assert result.data["utmos_mos"] == 4.2 + assert result.data["sigmos_noise"] == 4.0 + assert result.data["sigmos_ovrl"] == 3.5 + assert result.data["band_prediction"] == "full_band" + assert result.data["sample_rate"] == 48000 + + +def test_custom_passthrough_keys() -> None: + """User can restrict output to only specific keys.""" + task = _make_task( + { + "audio_filepath": "/a.wav", + "duration": 1.0, + "sigmos_ovrl": 3.0, + "sigmos_noise": 4.0, + "utmos_mos": 4.2, + "book_id": "123", + } + ) + stage = TimestampMapperStage(passthrough_keys=["sigmos_ovrl", "book_id"]) + result = stage.process(task) + + assert result.data["sigmos_ovrl"] == 3.0 + assert result.data["book_id"] == "123" + assert "sigmos_noise" not in result.data + assert "utmos_mos" not in result.data + + +def test_dataset_metadata_not_in_default_output() -> None: + """Dataset-specific keys (text, book_id) are excluded by default passthrough.""" + task = _make_task( + { + "audio_filepath": "/a.wav", + "duration": 1.0, + "text": "hello world", + "book_id": "123", + "reader_id": "456", + 
"sigmos_ovrl": 3.0, + } + ) + stage = TimestampMapperStage() + result = stage.process(task) - assert "speaker_id" not in result.data - assert result.data["utmos_mos"] == 4.0 + assert result.data["sigmos_ovrl"] == 3.0 + assert "text" not in result.data + assert "book_id" not in result.data + assert "reader_id" not in result.data diff --git a/tests/stages/audio/segmentation/test_speaker_separation.py b/tests/stages/audio/segmentation/test_speaker_separation.py index b067a87b17..ea070d64b0 100644 --- a/tests/stages/audio/segmentation/test_speaker_separation.py +++ b/tests/stages/audio/segmentation/test_speaker_separation.py @@ -43,8 +43,8 @@ def test_process_returns_per_speaker_tasks(self, mock_init: MagicMock) -> None: separator = MagicMock() speaker_data = { - "speaker_0": (_make_audio_segment(3000), 3.0), - "speaker_1": (_make_audio_segment(4000), 4.0), + "speaker_0": (_make_audio_segment(3000), 3.0, [(0.0, 3.0)]), + "speaker_1": (_make_audio_segment(4000), 4.0, [(0.0, 4.0)]), } separator.get_speaker_audio_data.return_value = speaker_data stage._separator = separator @@ -58,7 +58,7 @@ def test_process_returns_per_speaker_tasks(self, mock_init: MagicMock) -> None: assert "speaker_id" in r.data assert "num_speakers" in r.data assert r.data["num_speakers"] == 2 - assert "duration_sec" in r.data + assert "duration" in r.data @patch("nemo_curator.stages.audio.segmentation.speaker_separation.SpeakerSeparationStage._initialize_separator") def test_process_output_keys(self, mock_init: MagicMock) -> None: @@ -66,7 +66,7 @@ def test_process_output_keys(self, mock_init: MagicMock) -> None: separator = MagicMock() separator.get_speaker_audio_data.return_value = { - "spk_0": (_make_audio_segment(5000), 5.0), + "spk_0": (_make_audio_segment(5000), 5.0, [(0.0, 5.0)]), } stage._separator = separator @@ -76,7 +76,7 @@ def test_process_output_keys(self, mock_init: MagicMock) -> None: item = result[0].data assert item["speaker_id"] == "spk_0" assert item["num_speakers"] == 1 - assert item["duration_sec"] == 5.0 + assert item["duration"] == 5.0 assert "waveform" in item assert "sample_rate" in item @@ -86,8 +86,8 @@ def test_min_duration_filters_short_speakers(self, mock_init: MagicMock) -> None separator = MagicMock() separator.get_speaker_audio_data.return_value = { - "speaker_0": (_make_audio_segment(5000), 5.0), - "speaker_1": (_make_audio_segment(1000), 1.0), + "speaker_0": (_make_audio_segment(5000), 5.0, [(0.0, 5.0)]), + "speaker_1": (_make_audio_segment(1000), 1.0, [(0.0, 1.0)]), } stage._separator = separator diff --git a/tests/stages/audio/segmentation/test_vad_segmentation.py b/tests/stages/audio/segmentation/test_vad_segmentation.py index ff51f9a7c8..b4f3d01c24 100644 --- a/tests/stages/audio/segmentation/test_vad_segmentation.py +++ b/tests/stages/audio/segmentation/test_vad_segmentation.py @@ -56,7 +56,7 @@ def test_process_returns_segments(self, mock_load_vad: MagicMock, mock_get_ts: M assert "start_ms" in seg.data assert "end_ms" in seg.data assert "segment_num" in seg.data - assert "duration_sec" in seg.data + assert "duration" in seg.data @patch("nemo_curator.stages.audio.segmentation.vad_segmentation.get_speech_timestamps") @patch("nemo_curator.stages.audio.segmentation.vad_segmentation.load_silero_vad") @@ -79,7 +79,7 @@ def test_process_output_keys(self, mock_load_vad: MagicMock, mock_get_ts: MagicM assert result[0].data["start_ms"] == 0 assert result[0].data["segment_num"] == 0 - assert result[0].data["duration_sec"] > 0 + assert result[0].data["duration"] > 0 assert 
result[0].data["sample_rate"] == sr @patch("nemo_curator.stages.audio.segmentation.vad_segmentation.get_speech_timestamps") @@ -177,7 +177,7 @@ def test_nested_mode_returns_single_task(self, mock_load_vad: MagicMock, mock_ge assert "start_ms" in seg assert "end_ms" in seg assert "segment_num" in seg - assert "duration_sec" in seg + assert "duration" in seg assert "original_file" in seg @patch("nemo_curator.stages.audio.segmentation.vad_segmentation.get_speech_timestamps") diff --git a/tutorials/audio/readspeech/README.md b/tutorials/audio/readspeech/README.md index 8e7352d945..f48e276fdf 100644 --- a/tutorials/audio/readspeech/README.md +++ b/tutorials/audio/readspeech/README.md @@ -67,14 +67,26 @@ raw_data_dir/ ## Pipeline Architecture +The pipeline supports four topologies based on which features are enabled: + +| Topology | Flags | Output (per input file) | +|----------|-------|------------------------| +| Combo 1 | *(none)* | 1 row with whole-file scores | +| Combo 2 | `--enable-vad` | N rows, one per speech segment | +| Combo 3 | `--enable-speaker-separation` | K rows, one per speaker with diarization timestamps | +| Combo 4 | `--enable-vad --enable-speaker-separation` | K*M rows, one per speaker-segment | + ``` CreateInitialManifestReadSpeechStage - Downloads and scans read_speech directory, parses filenames, creates AudioTask + Downloads and scans read_speech directory, parses filenames | v -AudioDataFilterStage - Mono conversion -> VAD -> Band Filter -> UTMOS -> SIGMOS - -> Speaker Separation -> Timestamp Tracking +AudioDataFilterStage (auto-selects topology) + Combo 1: MonoConversion -> Filters -> TimestampMapper + Combo 2: MonoConversion -> VAD(fan-out) -> Filters -> TimestampMapper + Combo 3: MonoConversion -> Filters -> SpeakerSep(fan-out) -> Filters -> TimestampMapper + Combo 4: MonoConversion -> VAD(nested) -> Filters -> SegmentConcat + -> SpeakerSep -> VAD_Speaker(fan-out) -> Filters -> TimestampMapper | v AudioToDocumentStage -> JsonlWriter @@ -171,74 +183,123 @@ python run.py \ ## Output Format -Results saved to `{output_dir}/*.jsonl`: +Results saved to `{output_dir}/*.jsonl`. The output schema depends on the topology: + +### Core fields (always present) + +| Field | Description | +|-------|-------------| +| `original_file` | Path to the source audio file | +| `original_start_ms` | Start position in original file (ms) | +| `original_end_ms` | End position in original file (ms) | +| `duration_ms` | Duration in milliseconds | +| `duration` | Duration in seconds | + +### Combo 3 additional fields (speaker-only) + +| Field | Description | +|-------|-------------| +| `diar_segments` | List of `[start_sec, end_sec]` pairs for when the speaker talks | +| `speaking_duration` | Total speaking time in seconds (sum of diar_segments) | + +### Passthrough fields (controlled by `passthrough_keys`) + +These fields are copied from the pipeline stages to the output. +By default, all built-in filter scores are included: + +| Field | Source | Default | +|-------|--------|---------| +| `speaker_id` | SpeakerSeparation | included | +| `num_speakers` | SpeakerSeparation | included | +| `sample_rate` | MonoConversion | included | +| `utmos_mos` | UTMOSFilter | included | +| `sigmos_noise`, `sigmos_ovrl`, ... 
| SIGMOSFilter | included | +| `band_prediction` | BandFilter | included | + +To customize which fields appear in output, set `passthrough_keys` in the config: +```python +AudioDataFilterStage(config={ + "timestamp_mapper": { + "passthrough_keys": ["utmos_mos", "sigmos_ovrl"], # only these + }, +}) +``` + +**Safety**: Non-serializable fields (`waveform`, `audio`, `segments`, etc.) +are always blocked, even if added to `passthrough_keys`. +A warning is logged if blocked keys are detected in the configuration. + +### Example outputs + +**Combo 1** (no VAD, no speaker): ```json -{ - "audio_filepath": "/path/to/read_speech/book_00000_chp_0009_reader_06709_0_seg_1_seg1.wav", - "sample_rate": 48000, - "book_id": "00000", - "reader_id": "06709", - "original_start_ms": 1500, - "original_end_ms": 5200, - "duration_ms": 3700, - "duration_sec": 3.7, - "speaker_id": "speaker_0", - "utmos_mos": 3.9, - "sigmos_noise": 4.2, - "band_prediction": "full_band" -} +{"original_file": "/path/to/file.wav", "original_start_ms": 0, "original_end_ms": 10500, "duration_ms": 10500, "duration": 10.5, "utmos_mos": 3.9, "sigmos_ovrl": 3.5} +``` + +**Combo 2** (VAD only): +```json +{"original_file": "/path/to/file.wav", "original_start_ms": 5200, "original_end_ms": 13200, "duration_ms": 8000, "duration": 8.0, "utmos_mos": 4.1, "sigmos_ovrl": 3.7} +``` + +**Combo 3** (speaker only): +```json +{"original_file": "/path/to/file.wav", "original_start_ms": 5200, "original_end_ms": 120500, "duration_ms": 115300, "duration": 115.3, "speaking_duration": 43.4, "diar_segments": [[5.2, 15.4], [30.1, 42.8], [100.0, 120.5]], "speaker_id": "speaker_0", "num_speakers": 3} +``` + +**Combo 4** (VAD + speaker): +```json +{"original_file": "/path/to/file.wav", "original_start_ms": 7200, "original_end_ms": 11200, "duration_ms": 4000, "duration": 4.0, "speaker_id": "speaker_0", "num_speakers": 3, "utmos_mos": 4.2} ``` ## Extracting Audio Segments -After the pipeline produces a `manifest.jsonl`, use `extract_segments.py` to extract the actual audio segments from the original files based on the timestamps in the manifest. +After the pipeline produces a `manifest.jsonl`, use `extract_segments.py` to extract the actual audio segments from the original files. The script auto-detects the pipeline topology from the manifest schema. 
### Basic Usage ```bash -# Extract segments from a single manifest file -python extract_segments.py \ - --manifest ./dns_data/result/manifest.jsonl \ - --output-dir ./extracted_segments +# Extract from a single manifest file +python extract_segments.py -m ./dns_data/result/manifest.jsonl -o ./extracted/ # Extract from a directory of jsonl files (auto-combines them) -python extract_segments.py \ - --manifest ./dns_data/result/ \ - --output-dir ./extracted_segments - -# Output as FLAC instead of WAV -python extract_segments.py \ - --manifest ./dns_data/result/manifest.jsonl \ - --output-dir ./extracted_segments \ - --output-format flac +python extract_segments.py -m ./dns_data/result/ -o ./extracted/ + +# Output as FLAC +python extract_segments.py -m ./dns_data/result/ -o ./extracted/ -f flac ``` -### Options +### Extraction per topology -| Option | Default | Description | -|--------|---------|-------------| -| `--manifest, -m` | required | Path to manifest.jsonl or directory of .jsonl files | -| `--output-dir, -o` | required | Directory for extracted audio segments | -| `--output-format, -f` | `wav` | Output format: `wav`, `flac`, or `ogg` (via soundfile) | -| `--verbose, -v` | `false` | Enable verbose (DEBUG) logging | - -### Output +| Topology | What it extracts | File naming | +|----------|-----------------|-------------| +| Combo 1 | Full file copy | `{name}.wav` | +| Combo 2 | Each VAD segment | `{name}_segment_000.wav` | +| Combo 3 | Each speaking interval per speaker | `{name}_speaker_0_segment_000.wav` | +| Combo 4 | Each speaker-segment | `{name}_speaker_0_segment_000.wav` | -Extracted files are named based on the original filename with speaker and segment info: +### Output files ``` -extracted_segments/ -├── book_00025_chp_0019_reader_04069_speaker_0_segment_000.wav -├── book_00025_chp_0019_reader_04069_speaker_0_segment_001.wav -├── book_00025_chp_0019_reader_04069_speaker_1_segment_000.wav -├── manifest.jsonl # Combined manifest (when input is a directory) -└── extraction_summary.json # Statistics summary +extracted/ +├── {name}_speaker_0_segment_000.wav # Audio segments +├── {name}_speaker_0_segment_001.wav +├── metadata.csv # Per-segment metadata with quality scores +├── manifest.jsonl # Combined manifest (when input is a directory) +└── extraction_summary.json # Statistics summary ``` -Without speaker separation, files are named `{original_name}_segment_{num}.wav`. +The `metadata.csv` contains one row per extracted segment with columns: +`filename`, `original_file`, `start_sec`, `end_sec`, `duration`, and all quality scores from the manifest. -The script also generates an `extraction_summary.json` with statistics including total segments extracted, total duration, and per-speaker segment counts. +### Options + +| Option | Default | Description | +|--------|---------|-------------| +| `--manifest, -m` | required | Path to manifest.jsonl or directory of .jsonl files | +| `--output-dir, -o` | required | Directory for extracted audio segments | +| `--output-format, -f` | `wav` | Output format: `wav`, `flac`, or `ogg` | +| `--verbose, -v` | `false` | Enable verbose (DEBUG) logging | > **Note**: Supported output formats are `wav`, `flac`, and `ogg` via `soundfile`. 
diff --git a/tutorials/audio/readspeech/extract_segments.py b/tutorials/audio/readspeech/extract_segments.py index df666acdb4..a65b71d9ea 100755 --- a/tutorials/audio/readspeech/extract_segments.py +++ b/tutorials/audio/readspeech/extract_segments.py @@ -16,29 +16,46 @@ Segment Extraction Script Reads manifest jsonl file(s) and extracts audio segments from original files. -Each segment is saved with naming convention: - With speaker separation: {original_filename}_speaker_{x}_segment_{y}.{format} - Without speaker separation: {original_filename}_segment_{y}.{format} +Automatically detects the pipeline combo from the manifest schema and applies +the appropriate extraction strategy: + + Combo 1 (no VAD, no speaker): + Copies the full original file as-is. + Output: {original_filename}.{format} + + Combo 2 (VAD only): + Extracts each VAD segment by original_start_ms / original_end_ms. + Output: {original_filename}_segment_{NNN}.{format} + Segments are numbered in ascending order of start time. + + Combo 3 (speaker only): + Extracts each speaking interval from diar_segments per speaker. + Output: {original_filename}_speaker_{X}_segment_{NNN}.{format} + Segments are numbered per speaker in ascending order. + + Combo 4 (VAD + speaker): + Extracts each speaker-segment by original_start_ms / original_end_ms. + Output: {original_filename}_speaker_{X}_segment_{NNN}.{format} + Segments are numbered per speaker in ascending order of start time. Input can be: - A single manifest.jsonl file - - A directory containing multiple .jsonl files (from pipeline executor output) - -When given a directory, all .jsonl files are combined into a single -manifest.jsonl in the output directory with escaped paths (\\/) cleaned up. + - A directory containing multiple .jsonl files Supports configurable output format: wav, flac, ogg (via soundfile). Usage: - python extract_segments.py --manifest manifest.jsonl --output-dir extracted_segments/ + python extract_segments.py --manifest manifest.jsonl --output-dir extracted/ python extract_segments.py --manifest /path/to/result_dir/ --output-dir out/ - python extract_segments.py --manifest /path/to/result_dir/ --output-dir out/ --output-format flac + python extract_segments.py --manifest result_dir/ --output-dir out/ --output-format flac """ import argparse +import csv import glob import json import os +import shutil import sys from collections import defaultdict from pathlib import Path @@ -55,31 +72,53 @@ "ogg": "VORBIS", } +_CSV_STRUCTURAL_KEYS = frozenset( + { + "filename", + "original_file", + "original_start_ms", + "original_end_ms", + "duration_ms", + "start_sec", + "end_sec", + "duration", + "segment_index", + "speaker_id", + "num_speakers", + "speaking_duration", + "diar_segments", + } +) + + +def _extract_scores(entry: dict) -> dict: + """Extract quality/filter score fields from a manifest entry. + + Returns all keys that are not structural CSV columns (timestamps, + duration, speaker info), with float values rounded for readability. + Since TimestampMapper already whitelist-filters the manifest output, + anything remaining is a quality score or user-defined field. 
+ """ + return {k: round(v, 4) if isinstance(v, float) else v for k, v in entry.items() if k not in _CSV_STRUCTURAL_KEYS} def load_manifest(manifest_path: str) -> list: - """Load a single manifest.jsonl file and return list of segment entries.""" - segments = [] + """Load a single manifest.jsonl file and return list of entries.""" + entries = [] with open(manifest_path) as f: for line_num, raw_line in enumerate(f, 1): line = raw_line.strip() if not line: continue try: - segment = json.loads(line) - segments.append(segment) + entries.append(json.loads(line)) except json.JSONDecodeError as e: logger.warning(f"Failed to parse line {line_num} in {manifest_path}: {e}") - return segments + return entries def load_manifests(input_path: str, output_dir: str) -> list: - """ - Load segments from a single jsonl file or a directory of jsonl files. - - When input_path is a directory, all .jsonl files are combined and a - merged manifest.jsonl is saved in output_dir with escaped paths fixed. - """ + """Load entries from a single jsonl file or a directory of jsonl files.""" if os.path.isfile(input_path): return load_manifest(input_path) @@ -94,157 +133,385 @@ def load_manifests(input_path: str, output_dir: str) -> list: logger.info(f"Found {len(jsonl_files)} jsonl files in {input_path}") - all_segments = [] - skipped_files = 0 + all_entries = [] for jf in jsonl_files: - segs = load_manifest(jf) - if not segs: - skipped_files += 1 - continue - all_segments.extend(segs) + all_entries.extend(load_manifest(jf)) - if skipped_files: - logger.info(f"Skipped {skipped_files} empty jsonl file(s)") - logger.info(f"Combined {len(all_segments)} segments from {len(jsonl_files) - skipped_files} file(s)") + logger.info(f"Combined {len(all_entries)} entries from {len(jsonl_files)} file(s)") - if all_segments: + if all_entries: os.makedirs(output_dir, exist_ok=True) combined_path = os.path.join(output_dir, "manifest.jsonl") with open(combined_path, "w") as f: - f.writelines(json.dumps(seg) + "\n" for seg in all_segments) + f.writelines(json.dumps(e) + "\n" for e in all_entries) logger.info(f"Saved combined manifest to {combined_path}") - return all_segments + return all_entries + +def detect_combo(entries: list) -> int: + """Detect which pipeline combo produced the manifest. + + Returns 1, 2, 3, or 4. + """ + if not entries: + return 1 -def _write_segment( - output_path: str, segment_audio: np.ndarray, sample_rate: int, output_format: str -) -> None: + first = entries[0] + has_speaker = "speaker_id" in first + has_diar = "diar_segments" in first + has_timestamps = "original_start_ms" in first and "original_end_ms" in first + + if has_speaker and has_diar: + return 3 + if has_speaker and has_timestamps: + return 4 + if has_timestamps and not has_speaker: + return 2 + return 1 + + +def _write_segment(output_path: str, audio: np.ndarray, sample_rate: int, output_format: str) -> None: """Write a single audio segment to disk.""" - sf.write(output_path, segment_audio, sample_rate, subtype=SOUNDFILE_FORMATS[output_format]) - - -def _process_file_segments( - original_file: str, - file_segments: list, - output_dir: str, - output_format: str, - speaker_counts: dict, -) -> tuple[int, float]: - """Process all segments for a single original file. 
Returns (extracted_count, duration_sec).""" - original_name = Path(original_file).stem - logger.info(f"\nProcessing: {original_name}") - logger.info(f" Original file: {original_file}") - logger.info(f" Segments to extract: {len(file_segments)}") - - try: - file_info = sf.info(original_file) - sample_rate = file_info.samplerate - total_samples = file_info.frames - logger.info(f" Original duration: {total_samples / sample_rate:.2f}s") - except Exception as e: # noqa: BLE001 - logger.error(f" Failed to read audio info: {e}") - return 0, 0.0 - - has_speakers = any("speaker_id" in seg for seg in file_segments) - if has_speakers: - file_segments.sort(key=lambda x: (x.get("speaker_id", ""), x.get("original_start_ms", 0))) - else: - file_segments.sort(key=lambda x: x.get("original_start_ms", 0)) - - segment_counts = defaultdict(int) + sf.write(output_path, audio, sample_rate, subtype=SOUNDFILE_FORMATS[output_format]) + + +def _read_segment(filepath: str, start_ms: int, end_ms: int, sample_rate: int) -> np.ndarray: + """Read a slice of audio from a file.""" + start_sample = int(start_ms * sample_rate / 1000) + end_sample = int(end_ms * sample_rate / 1000) + audio, _ = sf.read(filepath, start=start_sample, stop=end_sample, dtype="float32") + return audio + + +# ------------------------------------------------------------------ +# Combo 1: no VAD, no speaker -- copy full file +# ------------------------------------------------------------------ + + +def extract_combo1( + entries: list, output_dir: str, output_format: str +) -> tuple[int, float, dict[str, int], list[dict]]: + """Copy the full original file(s) as-is.""" extracted = 0 - duration_total = 0.0 + total_dur = 0.0 + metadata_rows: list[dict] = [] + speaker_counts: dict[str, int] = {} - for seg in file_segments: - start_ms = seg.get("original_start_ms", 0) - end_ms = seg.get("original_end_ms", 0) - speaker_id = seg.get("speaker_id") - duration_sec = seg.get("duration_sec", (end_ms - start_ms) / 1000) + for entry in entries: + original_file = entry.get("original_file") or entry.get("audio_filepath") + if not original_file or not os.path.exists(original_file): + logger.error(f"Original file not found: {original_file}") + continue - count_key = speaker_id or "__all__" - segment_num = segment_counts[count_key] - segment_counts[count_key] += 1 + original_name = Path(original_file).stem + out_filename = f"{original_name}.{output_format}" + output_path = os.path.join(output_dir, out_filename) - if speaker_id: - speaker_num = speaker_id.replace("speaker_", "") if "speaker_" in speaker_id else speaker_id - output_filename = f"{original_name}_speaker_{speaker_num}_segment_{segment_num:03d}.{output_format}" + if output_format == "wav" and original_file.endswith(".wav"): + shutil.copy2(original_file, output_path) else: - output_filename = f"{original_name}_segment_{segment_num:03d}.{output_format}" - output_path = os.path.join(output_dir, output_filename) - - try: - start_sample = int(start_ms * sample_rate / 1000) - end_sample = int(end_ms * sample_rate / 1000) - segment_audio, _ = sf.read( - original_file, start=start_sample, stop=end_sample, dtype="float32", - ) - _write_segment(output_path, segment_audio, sample_rate, output_format) - extracted += 1 - duration_total += duration_sec - if speaker_id: - speaker_counts[speaker_id] += 1 - logger.debug(f" Extracted: {output_filename} ({duration_sec:.2f}s)") - except Exception as e: # noqa: BLE001 - logger.error(f" Failed to extract segment {segment_num}: {e}") + audio, sr = sf.read(original_file, 
dtype="float32") + _write_segment(output_path, audio, sr, output_format) + + dur = entry.get("duration", 0) + total_dur += dur + extracted += 1 + logger.info(f" Copied: {output_path} ({dur:.1f}s)") + + metadata_rows.append( + { + "filename": out_filename, + "original_file": original_file, + "start_sec": 0.0, + "end_sec": round(dur, 3), + "duration": round(dur, 3), + **_extract_scores(entry), + } + ) - logger.info(f" Extracted {sum(segment_counts.values())} segments from this file") - return extracted, duration_total + return extracted, total_dur, speaker_counts, metadata_rows -def extract_segments(input_path: str, output_dir: str, output_format: str = DEFAULT_OUTPUT_FORMAT) -> None: - """ - Extract segments from original audio files based on manifest. +# ------------------------------------------------------------------ +# Combo 2: VAD only -- extract each segment by timestamps +# ------------------------------------------------------------------ - Args: - input_path: Path to manifest.jsonl file or directory of .jsonl files - output_dir: Directory to save extracted segments - output_format: Output audio format (wav, flac, ogg). Default: wav. - """ - os.makedirs(output_dir, exist_ok=True) - logger.info(f"Loading manifest: {input_path}") - segments = load_manifests(input_path, output_dir) - logger.info(f"Found {len(segments)} segments total") +def extract_combo2( + entries: list, output_dir: str, output_format: str +) -> tuple[int, float, dict[str, int], list[dict]]: + """Extract VAD segments sorted by start time.""" + by_file = defaultdict(list) + for entry in entries: + original_file = entry.get("original_file", "") + by_file[original_file].append(entry) - if not segments: - logger.error("No segments found in manifest") - return + extracted = 0 + total_dur = 0.0 + speaker_counts: dict[str, int] = {} + metadata_rows: list[dict] = [] + + for original_file, segments in by_file.items(): + if not os.path.exists(original_file): + logger.error(f"Original file not found: {original_file}") + continue + + info = sf.info(original_file) + original_name = Path(original_file).stem + logger.info(f"\nProcessing: {original_name} ({len(segments)} segments)") + + segments.sort(key=lambda x: x.get("original_start_ms", 0)) - segments_by_file = defaultdict(list) - for seg in segments: - original_file = seg.get("original_file") - if original_file: - segments_by_file[original_file].append(seg) + for i, seg in enumerate(segments): + start_ms = seg.get("original_start_ms", 0) + end_ms = seg.get("original_end_ms", 0) + dur = seg.get("duration", (end_ms - start_ms) / 1000) - logger.info(f"Segments span {len(segments_by_file)} original file(s)") + out_filename = f"{original_name}_segment_{i:03d}.{output_format}" + output_path = os.path.join(output_dir, out_filename) - total_extracted = 0 - total_duration_sec = 0.0 + try: + audio = _read_segment(original_file, start_ms, end_ms, info.samplerate) + _write_segment(output_path, audio, info.samplerate, output_format) + extracted += 1 + total_dur += dur + logger.debug(f" {out_filename} ({start_ms}-{end_ms}ms, {dur:.2f}s)") + + metadata_rows.append( + { + "filename": out_filename, + "original_file": original_file, + "segment_index": i, + "start_sec": round(start_ms / 1000, 3), + "end_sec": round(end_ms / 1000, 3), + "duration": round(dur, 3), + **_extract_scores(seg), + } + ) + except Exception as e: # noqa: BLE001 + logger.error(f" Failed to extract {out_filename}: {e}") + + return extracted, total_dur, speaker_counts, metadata_rows + + +# 
------------------------------------------------------------------ +# Combo 3: speaker only -- extract each diar_segment per speaker +# ------------------------------------------------------------------ + + +def extract_combo3(entries: list, output_dir: str, output_format: str) -> tuple[int, float, dict, list[dict]]: + """Extract individual speaking intervals from diar_segments per speaker.""" + by_file = defaultdict(list) + for entry in entries: + original_file = entry.get("original_file", "") + by_file[original_file].append(entry) + + extracted = 0 + total_dur = 0.0 speaker_counts: dict[str, int] = defaultdict(int) + metadata_rows: list[dict] = [] - for original_file, file_segments in segments_by_file.items(): + for original_file, speaker_entries in by_file.items(): if not os.path.exists(original_file): logger.error(f"Original file not found: {original_file}") continue - extracted, duration = _process_file_segments( - original_file, - file_segments, - output_dir, - output_format, - speaker_counts, - ) - total_extracted += extracted - total_duration_sec += duration + + info = sf.info(original_file) + original_name = Path(original_file).stem + logger.info(f"\nProcessing: {original_name} ({len(speaker_entries)} speakers)") + + speaker_entries.sort(key=lambda x: x.get("speaker_id", "")) + + for entry in speaker_entries: + speaker_id = entry.get("speaker_id", "unknown") + speaker_num = speaker_id.replace("speaker_", "") if "speaker_" in speaker_id else speaker_id + num_speakers = entry.get("num_speakers", 0) + diar_segments = entry.get("diar_segments", []) + + scores = _extract_scores(entry) + + if not diar_segments: + logger.warning(f" {speaker_id}: no diar_segments, skipping") + continue + + diar_segments_sorted = sorted(diar_segments, key=lambda x: x[0]) + + logger.info(f" {speaker_id}: {len(diar_segments_sorted)} speaking intervals") + + for j, (start_sec, end_sec) in enumerate(diar_segments_sorted): + start_ms = int(start_sec * 1000) + end_ms = int(end_sec * 1000) + dur = end_sec - start_sec + + out_filename = f"{original_name}_speaker_{speaker_num}_segment_{j:03d}.{output_format}" + output_path = os.path.join(output_dir, out_filename) + + try: + audio = _read_segment(original_file, start_ms, end_ms, info.samplerate) + _write_segment(output_path, audio, info.samplerate, output_format) + extracted += 1 + total_dur += dur + speaker_counts[speaker_id] += 1 + logger.debug(f" {out_filename} ({start_sec:.2f}-{end_sec:.2f}s, {dur:.2f}s)") + + metadata_rows.append( + { + "filename": out_filename, + "original_file": original_file, + "speaker_id": speaker_id, + "num_speakers": num_speakers, + "segment_index": j, + "start_sec": round(start_sec, 3), + "end_sec": round(end_sec, 3), + "duration": round(dur, 3), + **scores, + } + ) + except Exception as e: # noqa: BLE001 + logger.error(f" Failed to extract {out_filename}: {e}") + + return extracted, total_dur, speaker_counts, metadata_rows + + +# ------------------------------------------------------------------ +# Combo 4: VAD + speaker -- extract each speaker-segment by timestamps +# ------------------------------------------------------------------ + + +def extract_combo4(entries: list, output_dir: str, output_format: str) -> tuple[int, float, dict, list[dict]]: + """Extract speaker-segments using original_start_ms / original_end_ms.""" + by_file = defaultdict(list) + for entry in entries: + original_file = entry.get("original_file", "") + by_file[original_file].append(entry) + + extracted = 0 + total_dur = 0.0 + speaker_counts: dict[str, int] = 
defaultdict(int) + metadata_rows: list[dict] = [] + + for original_file, segments in by_file.items(): + if not os.path.exists(original_file): + logger.error(f"Original file not found: {original_file}") + continue + + info = sf.info(original_file) + original_name = Path(original_file).stem + logger.info(f"\nProcessing: {original_name} ({len(segments)} speaker-segments)") + + segments.sort(key=lambda x: (x.get("speaker_id", ""), x.get("original_start_ms", 0))) + + per_speaker_count: dict[str, int] = defaultdict(int) + + for seg in segments: + speaker_id = seg.get("speaker_id", "unknown") + speaker_num = speaker_id.replace("speaker_", "") if "speaker_" in speaker_id else speaker_id + num_speakers = seg.get("num_speakers", 0) + start_ms = seg.get("original_start_ms", 0) + end_ms = seg.get("original_end_ms", 0) + dur = seg.get("duration", (end_ms - start_ms) / 1000) + + seg_idx = per_speaker_count[speaker_id] + per_speaker_count[speaker_id] += 1 + + out_filename = f"{original_name}_speaker_{speaker_num}_segment_{seg_idx:03d}.{output_format}" + output_path = os.path.join(output_dir, out_filename) + + try: + audio = _read_segment(original_file, start_ms, end_ms, info.samplerate) + _write_segment(output_path, audio, info.samplerate, output_format) + extracted += 1 + total_dur += dur + speaker_counts[speaker_id] += 1 + logger.debug(f" {out_filename} ({start_ms}-{end_ms}ms, {dur:.2f}s)") + + metadata_rows.append( + { + "filename": out_filename, + "original_file": original_file, + "speaker_id": speaker_id, + "num_speakers": num_speakers, + "segment_index": seg_idx, + "start_sec": round(start_ms / 1000, 3), + "end_sec": round(end_ms / 1000, 3), + "duration": round(dur, 3), + **_extract_scores(seg), + } + ) + except Exception as e: # noqa: BLE001 + logger.error(f" Failed to extract {out_filename}: {e}") + + return extracted, total_dur, speaker_counts, metadata_rows + + +# ------------------------------------------------------------------ +# Main +# ------------------------------------------------------------------ + + +def _write_metadata_csv(output_dir: str, metadata_rows: list[dict]) -> str: + """Write metadata.csv from collected metadata rows.""" + if not metadata_rows: + return "" + + all_keys: list[str] = [] + seen: set[str] = set() + for row in metadata_rows: + for k in row: + if k not in seen: + all_keys.append(k) + seen.add(k) + + csv_path = os.path.join(output_dir, "metadata.csv") + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=all_keys) + writer.writeheader() + writer.writerows(metadata_rows) + + return csv_path + + +def extract_segments(input_path: str, output_dir: str, output_format: str = DEFAULT_OUTPUT_FORMAT) -> None: + """Extract segments from original audio files based on manifest.""" + os.makedirs(output_dir, exist_ok=True) + + logger.info(f"Loading manifest: {input_path}") + entries = load_manifests(input_path, output_dir) + logger.info(f"Found {len(entries)} entries total") + + if not entries: + logger.error("No entries found in manifest") + return + + combo = detect_combo(entries) + combo_names = { + 1: "Combo 1 (no VAD, no speaker) -- copy full files", + 2: "Combo 2 (VAD only) -- extract VAD segments", + 3: "Combo 3 (speaker only) -- extract diar_segments per speaker", + 4: "Combo 4 (VAD + speaker) -- extract speaker-segments by timestamps", + } + logger.info(f"Detected: {combo_names[combo]}") + + extractors = { + 1: extract_combo1, + 2: extract_combo2, + 3: extract_combo3, + 4: extract_combo4, + } + total_extracted, total_dur, speaker_counts, 
metadata_rows = extractors[combo](entries, output_dir, output_format) + + csv_path = _write_metadata_csv(output_dir, metadata_rows) summary = { "manifest_path": input_path, "output_dir": output_dir, "total_segments": total_extracted, - "total_duration_sec": round(total_duration_sec, 2), - "segments_by_speaker": dict(speaker_counts), - "original_files_processed": len(segments_by_file), + "total_duration_sec": round(total_dur, 2), + "output_format": output_format, } + if speaker_counts: + summary["segments_by_speaker"] = dict(speaker_counts) summary_path = os.path.join(output_dir, "extraction_summary.json") with open(summary_path, "w") as f: @@ -253,14 +520,18 @@ def extract_segments(input_path: str, output_dir: str, output_format: str = DEFA logger.info(f"\n{'=' * 60}") logger.info("EXTRACTION COMPLETE") logger.info(f"{'=' * 60}") - logger.info(f"Total segments extracted: {total_extracted}") - logger.info(f"Total duration: {total_duration_sec:.2f}s ({total_duration_sec / 60:.1f} min)") - logger.info(f"Output directory: {output_dir}") + logger.info(f" Combo: {combo_names[combo]}") + logger.info(f" Total segments: {total_extracted}") + logger.info(f" Total duration: {total_dur:.2f}s ({total_dur / 60:.1f} min)") + logger.info(f" Output: {output_dir}") + logger.info(f" Format: {output_format}") if speaker_counts: - logger.info("\nSegments by speaker:") + logger.info(" Segments by speaker:") for speaker, count in sorted(speaker_counts.items()): - logger.info(f" {speaker}: {count} segments") - logger.info(f"\nSummary saved to: {summary_path}") + logger.info(f" {speaker}: {count} segments") + if csv_path: + logger.info(f" Metadata CSV: {csv_path}") + logger.info(f" Summary: {summary_path}") def main() -> int: @@ -275,13 +546,11 @@ def main() -> int: type=str, default=DEFAULT_OUTPUT_FORMAT, choices=["wav", "flac", "ogg"], - help="Output audio format (default: wav).", + help="Output audio format (default: wav)", ) parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose logging") - args = parser.parse_args() - # Configure logging if args.verbose: logger.remove() logger.add(lambda msg: print(msg, end=""), level="DEBUG") @@ -292,7 +561,6 @@ def main() -> int: logger.info(f"Output format: {args.output_format}") extract_segments(input_path=args.manifest, output_dir=args.output_dir, output_format=args.output_format) - return 0 diff --git a/tutorials/audio/readspeech/pipeline.py b/tutorials/audio/readspeech/pipeline.py index 55dc6fb533..de25ddaa3a 100644 --- a/tutorials/audio/readspeech/pipeline.py +++ b/tutorials/audio/readspeech/pipeline.py @@ -21,10 +21,30 @@ Dataset: Microsoft DNS Challenge 5 - Read Speech (Track 1 Headset) Source: https://github.com/microsoft/DNS-Challenge -The pipeline: -1. Creates initial manifest from read_speech WAV files (14,279 files at 48kHz) -2. Applies AudioDataFilterStage (VAD, quality filters, speaker separation) -3. Outputs filtered manifest with quality scores and timestamps +Pipeline supports four topologies depending on which features are enabled: + + Combo 1 (default, no flags): + MonoConversion -> Filters -> TimestampMapper -> JsonlWriter + Output: 1 row per file with whole-file quality scores. + + Combo 2 (--enable-vad): + MonoConversion -> VAD(fan-out) -> Filters -> TimestampMapper -> JsonlWriter + Output: 1 row per speech segment with per-segment scores and timestamps. 
+ + Combo 3 (--enable-speaker-separation): + MonoConversion -> Filters -> SpeakerSep(fan-out) -> Filters -> TimestampMapper + Output: 1 row per speaker with diarization timestamps and per-speaker scores. + + Combo 4 (--enable-vad --enable-speaker-separation): + Full pipeline with SegmentConcat + TimestampMapper remapping. + Output: 1 row per speaker-segment with precise timestamps. + +Output control: + TimestampMapper uses a whitelist (passthrough_keys) to control which + fields appear in the JSONL output. The default includes all built-in + filter scores (UTMOS, SIGMOS, BandFilter) and speaker metadata. + Non-serializable fields (waveform, segments) are always blocked. + To customize, set "passthrough_keys" in the timestamp_mapper config. Example: python pipeline.py --raw_data_dir /path/to/read_speech --enable-utmos --enable-vad @@ -102,16 +122,7 @@ def create_readspeech_pipeline(args: argparse.Namespace) -> Pipeline: "exclude_overlaps": args.speaker_exclude_overlaps, "min_duration": args.speaker_min_duration, }, - "timestamp_mapper": { - "passthrough_keys": [ - "band_prediction", - "utmos_mos", - "sigmos_noise", - "sigmos_ovrl", - "speaker_id", - "num_speakers", - ], - }, + "timestamp_mapper": {}, } ) ) @@ -182,6 +193,13 @@ def _build_parser() -> argparse.ArgumentParser: default="xenna", help="Execution backend: 'xenna' (default) or 'ray_data'", ) + parser.add_argument( + "--execution-mode", + type=str, + choices=["streaming", "batch"], + default="streaming", + help="Xenna execution mode: 'streaming' (default) or 'batch'", + ) parser.add_argument("--verbose", action="store_true", help="Verbose logging") parser.add_argument("--enable-vad", action="store_true", help="Enable VAD segmentation") parser.add_argument("--vad-min-duration", type=float, default=2.0, help="Min VAD segment (sec)") @@ -260,7 +278,11 @@ def main() -> None: logger.info("Starting pipeline execution...") try: - executor = RayDataExecutor() if args.backend == "ray_data" else XennaExecutor(config={"execution_mode": "streaming"}) + executor = ( + RayDataExecutor() + if args.backend == "ray_data" + else XennaExecutor(config={"execution_mode": args.execution_mode}) + ) pipeline.run(executor) logger.info(f"Results written to {args.output_dir}/*.jsonl") diff --git a/tutorials/audio/readspeech/pipeline.yaml b/tutorials/audio/readspeech/pipeline.yaml index c5cb1604c7..bbad9e2ad6 100644 --- a/tutorials/audio/readspeech/pipeline.yaml +++ b/tutorials/audio/readspeech/pipeline.yaml @@ -60,14 +60,6 @@ enable_speaker_separation: true speaker_exclude_overlaps: true speaker_min_duration: 0.8 -passthrough_keys: - - band_prediction - - utmos_mos - - sigmos_noise - - sigmos_ovrl - - speaker_id - - num_speakers - processors: - _target_: nemo_curator.stages.audio.datasets.readspeech.CreateInitialManifestReadSpeechStage raw_data_dir: ${raw_data_dir} @@ -100,8 +92,7 @@ processors: enable: ${enable_speaker_separation} exclude_overlaps: ${speaker_exclude_overlaps} min_duration: ${speaker_min_duration} - timestamp_mapper: - passthrough_keys: ${passthrough_keys} + timestamp_mapper: {} - _target_: nemo_curator.stages.audio.io.convert.AudioToDocumentStage From 77e0bab32c3af23b921bed7e21793ad628aa84b1 Mon Sep 17 00:00:00 2001 From: shbhawsar Date: Thu, 9 Apr 2026 05:58:44 +0000 Subject: [PATCH 02/11] Address PR review: fix unreachable combo detection, clarify core vs passthrough fields - Remove unreachable extract_combo1 (TimestampMapper always sets timestamps) - Merge combos 1 & 2 into extract_segments_by_timestamps - Rename extraction functions to 
be descriptive - Document diar_segments/speaking_duration as core fields in TimestampMapper docstring --- .../audio/postprocessing/timestamp_mapper.py | 13 ++- .../audio/readspeech/extract_segments.py | 98 +++++-------------- 2 files changed, 35 insertions(+), 76 deletions(-) diff --git a/nemo_curator/stages/audio/postprocessing/timestamp_mapper.py b/nemo_curator/stages/audio/postprocessing/timestamp_mapper.py index 622602513f..66b35a8eed 100644 --- a/nemo_curator/stages/audio/postprocessing/timestamp_mapper.py +++ b/nemo_curator/stages/audio/postprocessing/timestamp_mapper.py @@ -106,9 +106,15 @@ class TimestampMapperStage(ProcessingStage[AudioTask, AudioTask]): """ Normalize task data at the pipeline output boundary. - Constructs core output fields (``original_file``, timestamps, - duration) from available timing sources, then copies only the - keys listed in ``passthrough_keys`` from the input. + Constructs core output fields from available timing sources, + then copies only the keys listed in ``passthrough_keys`` from + the input. + + Core fields (always present, not controlled by passthrough_keys): + ``original_file``, ``original_start_ms``, ``original_end_ms``, + ``duration_ms``, ``duration``. + When diarization segments are available: ``diar_segments``, + ``speaking_duration`` are also set as core fields. Args: passthrough_keys: Keys to copy from input to output. @@ -213,6 +219,7 @@ def _build_output_item_no_mapping(self, item: dict[str, Any]) -> dict[str, Any]: diar_segments = item.get("diar_segments") if diar_segments and len(diar_segments) > 0: + diar_segments = sorted(diar_segments, key=lambda x: x[0]) first_start = diar_segments[0][0] last_end = diar_segments[-1][1] result["original_start_ms"] = int(first_start * 1000) diff --git a/tutorials/audio/readspeech/extract_segments.py b/tutorials/audio/readspeech/extract_segments.py index a65b71d9ea..c2260fe2db 100755 --- a/tutorials/audio/readspeech/extract_segments.py +++ b/tutorials/audio/readspeech/extract_segments.py @@ -20,11 +20,11 @@ the appropriate extraction strategy: Combo 1 (no VAD, no speaker): - Copies the full original file as-is. - Output: {original_filename}.{format} + Extracts the full file as a single segment (start=0, end=file duration). + Output: {original_filename}_segment_000.{format} Combo 2 (VAD only): - Extracts each VAD segment by original_start_ms / original_end_ms. + Extracts each VAD speech segment by original_start_ms / original_end_ms. Output: {original_filename}_segment_{NNN}.{format} Segments are numbered in ascending order of start time. @@ -55,7 +55,6 @@ import glob import json import os -import shutil import sys from collections import defaultdict from pathlib import Path @@ -152,23 +151,27 @@ def load_manifests(input_path: str, output_dir: str) -> list: def detect_combo(entries: list) -> int: """Detect which pipeline combo produced the manifest. - Returns 1, 2, 3, or 4. + Returns 2, 3, or 4. Since TimestampMapper always emits + ``original_start_ms``/``original_end_ms``, combos 1 and 2 are + indistinguishable and both use timestamp-based extraction. 
+ + Returns: + 2: segments by timestamps (combos 1 and 2) + 3: speaker diarization segments + 4: speaker-segments by timestamps """ if not entries: - return 1 + return 2 first = entries[0] has_speaker = "speaker_id" in first has_diar = "diar_segments" in first - has_timestamps = "original_start_ms" in first and "original_end_ms" in first if has_speaker and has_diar: return 3 - if has_speaker and has_timestamps: + if has_speaker: return 4 - if has_timestamps and not has_speaker: - return 2 - return 1 + return 2 def _write_segment(output_path: str, audio: np.ndarray, sample_rate: int, output_format: str) -> None: @@ -185,63 +188,14 @@ def _read_segment(filepath: str, start_ms: int, end_ms: int, sample_rate: int) - # ------------------------------------------------------------------ -# Combo 1: no VAD, no speaker -- copy full file -# ------------------------------------------------------------------ - - -def extract_combo1( - entries: list, output_dir: str, output_format: str -) -> tuple[int, float, dict[str, int], list[dict]]: - """Copy the full original file(s) as-is.""" - extracted = 0 - total_dur = 0.0 - metadata_rows: list[dict] = [] - speaker_counts: dict[str, int] = {} - - for entry in entries: - original_file = entry.get("original_file") or entry.get("audio_filepath") - if not original_file or not os.path.exists(original_file): - logger.error(f"Original file not found: {original_file}") - continue - - original_name = Path(original_file).stem - out_filename = f"{original_name}.{output_format}" - output_path = os.path.join(output_dir, out_filename) - - if output_format == "wav" and original_file.endswith(".wav"): - shutil.copy2(original_file, output_path) - else: - audio, sr = sf.read(original_file, dtype="float32") - _write_segment(output_path, audio, sr, output_format) - - dur = entry.get("duration", 0) - total_dur += dur - extracted += 1 - logger.info(f" Copied: {output_path} ({dur:.1f}s)") - - metadata_rows.append( - { - "filename": out_filename, - "original_file": original_file, - "start_sec": 0.0, - "end_sec": round(dur, 3), - "duration": round(dur, 3), - **_extract_scores(entry), - } - ) - - return extracted, total_dur, speaker_counts, metadata_rows - - -# ------------------------------------------------------------------ -# Combo 2: VAD only -- extract each segment by timestamps +# Combos 1 & 2: extract segments by timestamps # ------------------------------------------------------------------ -def extract_combo2( +def extract_segments_by_timestamps( entries: list, output_dir: str, output_format: str ) -> tuple[int, float, dict[str, int], list[dict]]: - """Extract VAD segments sorted by start time.""" + """Extract segments by original_start_ms / original_end_ms, sorted by start time.""" by_file = defaultdict(list) for entry in entries: original_file = entry.get("original_file", "") @@ -300,7 +254,7 @@ def extract_combo2( # ------------------------------------------------------------------ -def extract_combo3(entries: list, output_dir: str, output_format: str) -> tuple[int, float, dict, list[dict]]: +def extract_speaker_diar_segments(entries: list, output_dir: str, output_format: str) -> tuple[int, float, dict, list[dict]]: """Extract individual speaking intervals from diar_segments per speaker.""" by_file = defaultdict(list) for entry in entries: @@ -379,7 +333,7 @@ def extract_combo3(entries: list, output_dir: str, output_format: str) -> tuple[ # ------------------------------------------------------------------ -def extract_combo4(entries: list, output_dir: str, output_format: 
str) -> tuple[int, float, dict, list[dict]]: +def extract_speaker_segments_by_timestamps(entries: list, output_dir: str, output_format: str) -> tuple[int, float, dict, list[dict]]: """Extract speaker-segments using original_start_ms / original_end_ms.""" by_file = defaultdict(list) for entry in entries: @@ -486,18 +440,16 @@ def extract_segments(input_path: str, output_dir: str, output_format: str = DEFA combo = detect_combo(entries) combo_names = { - 1: "Combo 1 (no VAD, no speaker) -- copy full files", - 2: "Combo 2 (VAD only) -- extract VAD segments", - 3: "Combo 3 (speaker only) -- extract diar_segments per speaker", - 4: "Combo 4 (VAD + speaker) -- extract speaker-segments by timestamps", + 2: "Segments by timestamps", + 3: "Speaker diarization segments", + 4: "Speaker-segments by timestamps", } logger.info(f"Detected: {combo_names[combo]}") extractors = { - 1: extract_combo1, - 2: extract_combo2, - 3: extract_combo3, - 4: extract_combo4, + 2: extract_segments_by_timestamps, + 3: extract_speaker_diar_segments, + 4: extract_speaker_segments_by_timestamps, } total_extracted, total_dur, speaker_counts, metadata_rows = extractors[combo](entries, output_dir, output_format) From dbe4813fe40f5a958cce5cc3ab7039260be22c46 Mon Sep 17 00:00:00 2001 From: shbhawsar Date: Wed, 15 Apr 2026 11:27:25 +0000 Subject: [PATCH 03/11] Fix test_decompose_all_disabled_except_mono to expect TimestampMapperStage TimestampMapperStage is the pipeline output normalizer and must always be present. Updated test assertion from 1 to 2 stages and added type check for TimestampMapperStage. --- .../stages/audio/advanced_pipelines/test_audio_data_filter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/stages/audio/advanced_pipelines/test_audio_data_filter.py b/tests/stages/audio/advanced_pipelines/test_audio_data_filter.py index 556e0567d3..2bab453a7e 100644 --- a/tests/stages/audio/advanced_pipelines/test_audio_data_filter.py +++ b/tests/stages/audio/advanced_pipelines/test_audio_data_filter.py @@ -27,6 +27,7 @@ load_config, ) from nemo_curator.stages.audio.filtering import BandFilterStage, SIGMOSFilterStage, UTMOSFilterStage +from nemo_curator.stages.audio.postprocessing import TimestampMapperStage from nemo_curator.stages.audio.preprocessing import MonoConversionStage, SegmentConcatenationStage from nemo_curator.stages.audio.segmentation import SpeakerSeparationStage, VADSegmentationStage @@ -203,8 +204,9 @@ def test_decompose_all_disabled_except_mono(self) -> None: "speaker_separation": {"enable": False}, }) stages = stage.decompose() - assert len(stages) == 1 + assert len(stages) == 2 assert isinstance(stages[0], MonoConversionStage) + assert isinstance(stages[1], TimestampMapperStage) def test_decompose_no_speaker_no_second_pass(self) -> None: stage = AudioDataFilterStage(config={ From bc9eb9a5d3d32b5634bdc0127564130aacc35b07 Mon Sep 17 00:00:00 2001 From: shbhawsar Date: Fri, 17 Apr 2026 11:41:14 -0700 Subject: [PATCH 04/11] Add non-serializable key sanitization and refactor extract_segments - AudioToDocumentStage: strip torch tensors and known non-serializable keys (waveform, audio, segments, etc.) before DataFrame construction as a safety net for upstream stages that miss cleanup. - TimestampMapper: add warning log when no timing info is found and a zero-duration row is emitted. - Add unit tests for _translate_to_original covering exact match, partial overlap, cross-boundary, silence gaps, and malformed mappings. 
- Refactor extract_segments.py: deduplicate combo extraction functions into a shared _extract_file_segments engine with combo-specific callables for intervals, filenames, and metadata. - Fix README Combo 1 description to match actual output naming. --- nemo_curator/stages/audio/io/convert.py | 38 ++- .../audio/postprocessing/timestamp_mapper.py | 5 + .../postprocessing/test_timestamp_mapper.py | 74 +++++ tutorials/audio/readspeech/README.md | 2 +- .../audio/readspeech/extract_segments.py | 310 +++++++++--------- 5 files changed, 267 insertions(+), 162 deletions(-) diff --git a/nemo_curator/stages/audio/io/convert.py b/nemo_curator/stages/audio/io/convert.py index 4352b61481..4edaa06a7d 100644 --- a/nemo_curator/stages/audio/io/convert.py +++ b/nemo_curator/stages/audio/io/convert.py @@ -13,10 +13,26 @@ # limitations under the License. import pandas as pd +from loguru import logger from nemo_curator.stages.base import ProcessingStage from nemo_curator.tasks import AudioTask, DocumentBatch +_NON_SERIALIZABLE_KEYS = frozenset( + { + "waveform", + "audio", + "audio_data", + "audio_array", + "segments", + } +) + + +def _is_tensor(v: object) -> bool: + """Check if a value is a torch.Tensor without importing torch at module level.""" + return type(v).__name__ == "Tensor" and type(v).__module__.startswith("torch") + class AudioToDocumentStage(ProcessingStage[AudioTask, DocumentBatch]): """Convert AudioTask entries into DocumentBatch DataFrames. @@ -26,6 +42,10 @@ class AudioToDocumentStage(ProcessingStage[AudioTask, DocumentBatch]): avoiding the overhead of many single-row DataFrames. Set ``batch_size`` to control how many audio entries land in each DataFrame (default 64). + + Non-serializable keys (torch tensors, raw audio arrays) are + stripped before building the DataFrame as a safety net, even if + upstream stages failed to clean them up. """ name = "AudioToDocumentStage" @@ -35,10 +55,26 @@ def process(self, task: AudioTask) -> DocumentBatch: msg = "AudioToDocumentStage only supports process_batch" raise NotImplementedError(msg) + @staticmethod + def _sanitize(data: dict) -> dict: + """Remove non-serializable keys and any remaining tensor values.""" + cleaned = {} + for k, v in data.items(): + if k in _NON_SERIALIZABLE_KEYS: + continue + if _is_tensor(v): + logger.warning( + f"[AudioToDocumentStage] Dropping non-serializable " + f"key {k!r} (torch.Tensor) before DataFrame conversion" + ) + continue + cleaned[k] = v + return cleaned + def process_batch(self, tasks: list[AudioTask]) -> list[DocumentBatch]: if len(tasks) == 0: return [] - df = pd.DataFrame([t.data for t in tasks]) + df = pd.DataFrame([self._sanitize(t.data) for t in tasks]) perf = [] for t in tasks: perf.extend(t._stage_perf) diff --git a/nemo_curator/stages/audio/postprocessing/timestamp_mapper.py b/nemo_curator/stages/audio/postprocessing/timestamp_mapper.py index 66b35a8eed..f6e4a99552 100644 --- a/nemo_curator/stages/audio/postprocessing/timestamp_mapper.py +++ b/nemo_curator/stages/audio/postprocessing/timestamp_mapper.py @@ -240,6 +240,11 @@ def _build_output_item_no_mapping(self, item: dict[str, Any]) -> dict[str, Any]: result["duration_ms"] = duration_ms result["duration"] = float(dur) else: + logger.warning( + f"[TimestampMapper] No timing information found for " + f"{result['original_file']!r} — emitting zero-duration row. " + f"This may indicate a corrupted or zero-length source file." 
+ ) result["original_start_ms"] = 0 result["original_end_ms"] = 0 result["duration_ms"] = 0 diff --git a/tests/stages/audio/postprocessing/test_timestamp_mapper.py b/tests/stages/audio/postprocessing/test_timestamp_mapper.py index 47e6af96fd..fa64578dd6 100644 --- a/tests/stages/audio/postprocessing/test_timestamp_mapper.py +++ b/tests/stages/audio/postprocessing/test_timestamp_mapper.py @@ -17,6 +17,7 @@ from nemo_curator.stages.audio.postprocessing.timestamp_mapper import ( _NEVER_PASS_KEYS, TimestampMapperStage, + _translate_to_original, ) from nemo_curator.tasks import AudioTask @@ -28,6 +29,79 @@ def _make_task(data: dict, task_id: str = "test", metadata: dict | None = None) return t +class TestTranslateToOriginal: + """Unit tests for the pure _translate_to_original() function.""" + + MAPPINGS = [ + {"concat_start_ms": 0, "concat_end_ms": 2000, "original_file": "a.wav", "original_start_ms": 5000, "original_end_ms": 7000}, + {"concat_start_ms": 2000, "concat_end_ms": 5000, "original_file": "b.wav", "original_start_ms": 0, "original_end_ms": 3000}, + {"concat_start_ms": 5000, "concat_end_ms": 8000, "original_file": "c.wav", "original_start_ms": 10000, "original_end_ms": 13000}, + ] + + def test_single_mapping_exact_match(self) -> None: + """Segment exactly matches one mapping.""" + results = _translate_to_original(self.MAPPINGS, 0, 2000) + assert len(results) == 1 + assert results[0]["original_file"] == "a.wav" + assert results[0]["original_start_ms"] == 5000 + assert results[0]["original_end_ms"] == 7000 + assert results[0]["duration_ms"] == 2000 + + def test_single_mapping_partial_overlap(self) -> None: + """Segment partially overlaps one mapping.""" + results = _translate_to_original(self.MAPPINGS, 500, 1500) + assert len(results) == 1 + assert results[0]["original_file"] == "a.wav" + assert results[0]["original_start_ms"] == 5500 + assert results[0]["original_end_ms"] == 6500 + assert results[0]["duration_ms"] == 1000 + + def test_cross_boundary_span(self) -> None: + """Segment spans two mappings — returns both.""" + results = _translate_to_original(self.MAPPINGS, 1500, 3000) + assert len(results) == 2 + assert results[0]["original_file"] == "a.wav" + assert results[0]["original_start_ms"] == 6500 + assert results[0]["original_end_ms"] == 7000 + assert results[0]["duration_ms"] == 500 + assert results[1]["original_file"] == "b.wav" + assert results[1]["original_start_ms"] == 0 + assert results[1]["original_end_ms"] == 1000 + assert results[1]["duration_ms"] == 1000 + + def test_silence_gap_no_overlap(self) -> None: + """Segment falls entirely in a gap between mappings.""" + mappings = [ + {"concat_start_ms": 0, "concat_end_ms": 1000, "original_file": "a.wav", "original_start_ms": 0, "original_end_ms": 1000}, + {"concat_start_ms": 3000, "concat_end_ms": 5000, "original_file": "b.wav", "original_start_ms": 0, "original_end_ms": 2000}, + ] + results = _translate_to_original(mappings, 1000, 3000) + assert len(results) == 0 + + def test_malformed_mapping_missing_key(self) -> None: + """Malformed mapping (missing key) is skipped gracefully.""" + mappings = [ + {"concat_start_ms": 0, "concat_end_ms": 2000}, + {"concat_start_ms": 2000, "concat_end_ms": 4000, "original_file": "b.wav", "original_start_ms": 0, "original_end_ms": 2000}, + ] + results = _translate_to_original(mappings, 0, 4000) + assert len(results) == 1 + assert results[0]["original_file"] == "b.wav" + + def test_empty_mappings(self) -> None: + """Empty mappings list returns empty results.""" + results = 
_translate_to_original([], 0, 1000) + assert results == [] + + def test_no_overlap_before_all_mappings(self) -> None: + """Segment ends before any mapping starts.""" + mappings = [ + {"concat_start_ms": 5000, "concat_end_ms": 8000, "original_file": "a.wav", "original_start_ms": 0, "original_end_ms": 3000}, + ] + results = _translate_to_original(mappings, 0, 1000) + assert results == [] + + def test_combo4_with_segment_mappings() -> None: """Full pipeline: remaps concat-space timestamps to original file positions.""" mappings = [ diff --git a/tutorials/audio/readspeech/README.md b/tutorials/audio/readspeech/README.md index f48e276fdf..14aae38d78 100644 --- a/tutorials/audio/readspeech/README.md +++ b/tutorials/audio/readspeech/README.md @@ -273,7 +273,7 @@ python extract_segments.py -m ./dns_data/result/ -o ./extracted/ -f flac | Topology | What it extracts | File naming | |----------|-----------------|-------------| -| Combo 1 | Full file copy | `{name}.wav` | +| Combo 1 | Full file (single segment) | `{name}_segment_000.wav` | | Combo 2 | Each VAD segment | `{name}_segment_000.wav` | | Combo 3 | Each speaking interval per speaker | `{name}_speaker_0_segment_000.wav` | | Combo 4 | Each speaker-segment | `{name}_speaker_0_segment_000.wav` | diff --git a/tutorials/audio/readspeech/extract_segments.py b/tutorials/audio/readspeech/extract_segments.py index c2260fe2db..96c2212d1e 100755 --- a/tutorials/audio/readspeech/extract_segments.py +++ b/tutorials/audio/readspeech/extract_segments.py @@ -58,6 +58,7 @@ import sys from collections import defaultdict from pathlib import Path +from typing import Any, Callable import numpy as np import soundfile as sf @@ -188,117 +189,64 @@ def _read_segment(filepath: str, start_ms: int, end_ms: int, sample_rate: int) - # ------------------------------------------------------------------ -# Combos 1 & 2: extract segments by timestamps +# Shared extraction engine # ------------------------------------------------------------------ +Interval = tuple[int, int, float] # (start_ms, end_ms, duration_sec) -def extract_segments_by_timestamps( - entries: list, output_dir: str, output_format: str -) -> tuple[int, float, dict[str, int], list[dict]]: - """Extract segments by original_start_ms / original_end_ms, sorted by start time.""" - by_file = defaultdict(list) - for entry in entries: - original_file = entry.get("original_file", "") - by_file[original_file].append(entry) - - extracted = 0 - total_dur = 0.0 - speaker_counts: dict[str, int] = {} - metadata_rows: list[dict] = [] - - for original_file, segments in by_file.items(): - if not os.path.exists(original_file): - logger.error(f"Original file not found: {original_file}") - continue - - info = sf.info(original_file) - original_name = Path(original_file).stem - logger.info(f"\nProcessing: {original_name} ({len(segments)} segments)") - - segments.sort(key=lambda x: x.get("original_start_ms", 0)) - - for i, seg in enumerate(segments): - start_ms = seg.get("original_start_ms", 0) - end_ms = seg.get("original_end_ms", 0) - dur = seg.get("duration", (end_ms - start_ms) / 1000) - - out_filename = f"{original_name}_segment_{i:03d}.{output_format}" - output_path = os.path.join(output_dir, out_filename) - - try: - audio = _read_segment(original_file, start_ms, end_ms, info.samplerate) - _write_segment(output_path, audio, info.samplerate, output_format) - extracted += 1 - total_dur += dur - logger.debug(f" {out_filename} ({start_ms}-{end_ms}ms, {dur:.2f}s)") - - metadata_rows.append( - { - "filename": out_filename, - 
"original_file": original_file, - "segment_index": i, - "start_sec": round(start_ms / 1000, 3), - "end_sec": round(end_ms / 1000, 3), - "duration": round(dur, 3), - **_extract_scores(seg), - } - ) - except Exception as e: # noqa: BLE001 - logger.error(f" Failed to extract {out_filename}: {e}") - - return extracted, total_dur, speaker_counts, metadata_rows - -# ------------------------------------------------------------------ -# Combo 3: speaker only -- extract each diar_segment per speaker -# ------------------------------------------------------------------ +def _get_speaker_label(entry: dict) -> tuple[str, str]: + """Return (speaker_id, speaker_num) from a manifest entry.""" + speaker_id = entry.get("speaker_id", "unknown") + speaker_num = speaker_id.replace("speaker_", "") if "speaker_" in speaker_id else speaker_id + return speaker_id, speaker_num -def extract_speaker_diar_segments(entries: list, output_dir: str, output_format: str) -> tuple[int, float, dict, list[dict]]: - """Extract individual speaking intervals from diar_segments per speaker.""" - by_file = defaultdict(list) +def _extract_file_segments( + entries: list, + output_dir: str, + output_format: str, + *, + sort_key: Callable[[dict], Any], + get_intervals: Callable[[dict], list[Interval]], + make_filename: Callable[[str, dict, int], str], + make_metadata: Callable[[str, str, dict, int, int, int, float], dict], +) -> tuple[int, float, dict[str, int], list[dict]]: + """Shared group-by-file -> read -> write -> metadata loop. + + Args: + entries: Manifest entries to extract. + output_dir: Where to write extracted audio files. + output_format: Audio format (wav, flac, ogg). + sort_key: How to sort entries within each file group. + get_intervals: Given an entry, return a list of (start_ms, end_ms, dur_sec). + make_filename: Given (original_name, entry, segment_index), return filename. + make_metadata: Given (filename, original_file, entry, seg_idx, start_ms, + end_ms, dur), return the metadata dict for this segment. 
+ """ + by_file: dict[str, list] = defaultdict(list) for entry in entries: - original_file = entry.get("original_file", "") - by_file[original_file].append(entry) + by_file[entry.get("original_file", "")].append(entry) extracted = 0 total_dur = 0.0 speaker_counts: dict[str, int] = defaultdict(int) metadata_rows: list[dict] = [] - for original_file, speaker_entries in by_file.items(): + for original_file, file_entries in by_file.items(): if not os.path.exists(original_file): logger.error(f"Original file not found: {original_file}") continue info = sf.info(original_file) original_name = Path(original_file).stem - logger.info(f"\nProcessing: {original_name} ({len(speaker_entries)} speakers)") + file_entries.sort(key=sort_key) + logger.info(f"\nProcessing: {original_name} ({len(file_entries)} entries)") - speaker_entries.sort(key=lambda x: x.get("speaker_id", "")) - - for entry in speaker_entries: - speaker_id = entry.get("speaker_id", "unknown") - speaker_num = speaker_id.replace("speaker_", "") if "speaker_" in speaker_id else speaker_id - num_speakers = entry.get("num_speakers", 0) - diar_segments = entry.get("diar_segments", []) - - scores = _extract_scores(entry) - - if not diar_segments: - logger.warning(f" {speaker_id}: no diar_segments, skipping") - continue - - diar_segments_sorted = sorted(diar_segments, key=lambda x: x[0]) - - logger.info(f" {speaker_id}: {len(diar_segments_sorted)} speaking intervals") - - for j, (start_sec, end_sec) in enumerate(diar_segments_sorted): - start_ms = int(start_sec * 1000) - end_ms = int(end_sec * 1000) - dur = end_sec - start_sec - - out_filename = f"{original_name}_speaker_{speaker_num}_segment_{j:03d}.{output_format}" + for entry in file_entries: + intervals = get_intervals(entry) + for seg_idx, (start_ms, end_ms, dur) in enumerate(intervals): + out_filename = make_filename(original_name, entry, seg_idx) output_path = os.path.join(output_dir, out_filename) try: @@ -306,97 +254,139 @@ def extract_speaker_diar_segments(entries: list, output_dir: str, output_format: _write_segment(output_path, audio, info.samplerate, output_format) extracted += 1 total_dur += dur - speaker_counts[speaker_id] += 1 - logger.debug(f" {out_filename} ({start_sec:.2f}-{end_sec:.2f}s, {dur:.2f}s)") + + speaker_id = entry.get("speaker_id") + if speaker_id: + speaker_counts[speaker_id] += 1 metadata_rows.append( - { - "filename": out_filename, - "original_file": original_file, - "speaker_id": speaker_id, - "num_speakers": num_speakers, - "segment_index": j, - "start_sec": round(start_sec, 3), - "end_sec": round(end_sec, 3), - "duration": round(dur, 3), - **scores, - } + make_metadata(out_filename, original_file, entry, seg_idx, start_ms, end_ms, dur) ) + logger.debug(f" {out_filename} ({start_ms}-{end_ms}ms, {dur:.2f}s)") except Exception as e: # noqa: BLE001 - logger.error(f" Failed to extract {out_filename}: {e}") + logger.error(f" Failed to extract {out_filename}: {e}") return extracted, total_dur, speaker_counts, metadata_rows # ------------------------------------------------------------------ -# Combo 4: VAD + speaker -- extract each speaker-segment by timestamps +# Combo-specific callables # ------------------------------------------------------------------ -def extract_speaker_segments_by_timestamps(entries: list, output_dir: str, output_format: str) -> tuple[int, float, dict, list[dict]]: - """Extract speaker-segments using original_start_ms / original_end_ms.""" - by_file = defaultdict(list) - for entry in entries: - original_file = entry.get("original_file", "") - 
by_file[original_file].append(entry) +def _intervals_from_timestamps(entry: dict) -> list[Interval]: + start_ms = entry.get("original_start_ms", 0) + end_ms = entry.get("original_end_ms", 0) + dur = entry.get("duration", (end_ms - start_ms) / 1000) + return [(start_ms, end_ms, dur)] - extracted = 0 - total_dur = 0.0 - speaker_counts: dict[str, int] = defaultdict(int) - metadata_rows: list[dict] = [] - for original_file, segments in by_file.items(): - if not os.path.exists(original_file): - logger.error(f"Original file not found: {original_file}") - continue +def _intervals_from_diar_segments(entry: dict) -> list[Interval]: + diar_segments = entry.get("diar_segments", []) + if not diar_segments: + speaker_id = entry.get("speaker_id", "unknown") + logger.warning(f" {speaker_id}: no diar_segments, skipping") + return [] + return [ + (int(s * 1000), int(e * 1000), e - s) + for s, e in sorted(diar_segments, key=lambda x: x[0]) + ] + + +def _base_metadata( + filename: str, original_file: str, entry: dict, + seg_idx: int, start_ms: int, end_ms: int, dur: float, +) -> dict: + row: dict = { + "filename": filename, + "original_file": original_file, + "segment_index": seg_idx, + "start_sec": round(start_ms / 1000, 3), + "end_sec": round(end_ms / 1000, 3), + "duration": round(dur, 3), + } + speaker_id = entry.get("speaker_id") + if speaker_id is not None: + row["speaker_id"] = speaker_id + num_speakers = entry.get("num_speakers") + if num_speakers is not None: + row["num_speakers"] = num_speakers + row.update(_extract_scores(entry)) + return row - info = sf.info(original_file) - original_name = Path(original_file).stem - logger.info(f"\nProcessing: {original_name} ({len(segments)} speaker-segments)") - segments.sort(key=lambda x: (x.get("speaker_id", ""), x.get("original_start_ms", 0))) +# ------------------------------------------------------------------ +# Combos 1 & 2: extract segments by timestamps +# ------------------------------------------------------------------ + + +def extract_segments_by_timestamps( + entries: list, output_dir: str, output_format: str, +) -> tuple[int, float, dict[str, int], list[dict]]: + """Extract segments by original_start_ms / original_end_ms, sorted by start time.""" + counter: dict[str, int] = defaultdict(int) + + def _make_filename(name: str, entry: dict, _seg_idx: int) -> str: + idx = counter[name] + counter[name] += 1 + return f"{name}_segment_{idx:03d}.{output_format}" + + return _extract_file_segments( + entries, output_dir, output_format, + sort_key=lambda x: x.get("original_start_ms", 0), + get_intervals=_intervals_from_timestamps, + make_filename=_make_filename, + make_metadata=_base_metadata, + ) + - per_speaker_count: dict[str, int] = defaultdict(int) +# ------------------------------------------------------------------ +# Combo 3: speaker only -- extract each diar_segment per speaker +# ------------------------------------------------------------------ - for seg in segments: - speaker_id = seg.get("speaker_id", "unknown") - speaker_num = speaker_id.replace("speaker_", "") if "speaker_" in speaker_id else speaker_id - num_speakers = seg.get("num_speakers", 0) - start_ms = seg.get("original_start_ms", 0) - end_ms = seg.get("original_end_ms", 0) - dur = seg.get("duration", (end_ms - start_ms) / 1000) - seg_idx = per_speaker_count[speaker_id] - per_speaker_count[speaker_id] += 1 +def extract_speaker_diar_segments( + entries: list, output_dir: str, output_format: str, +) -> tuple[int, float, dict[str, int], list[dict]]: + """Extract individual speaking 
intervals from diar_segments per speaker.""" - out_filename = f"{original_name}_speaker_{speaker_num}_segment_{seg_idx:03d}.{output_format}" - output_path = os.path.join(output_dir, out_filename) + def _make_filename(name: str, entry: dict, seg_idx: int) -> str: + _, speaker_num = _get_speaker_label(entry) + return f"{name}_speaker_{speaker_num}_segment_{seg_idx:03d}.{output_format}" - try: - audio = _read_segment(original_file, start_ms, end_ms, info.samplerate) - _write_segment(output_path, audio, info.samplerate, output_format) - extracted += 1 - total_dur += dur - speaker_counts[speaker_id] += 1 - logger.debug(f" {out_filename} ({start_ms}-{end_ms}ms, {dur:.2f}s)") - - metadata_rows.append( - { - "filename": out_filename, - "original_file": original_file, - "speaker_id": speaker_id, - "num_speakers": num_speakers, - "segment_index": seg_idx, - "start_sec": round(start_ms / 1000, 3), - "end_sec": round(end_ms / 1000, 3), - "duration": round(dur, 3), - **_extract_scores(seg), - } - ) - except Exception as e: # noqa: BLE001 - logger.error(f" Failed to extract {out_filename}: {e}") + return _extract_file_segments( + entries, output_dir, output_format, + sort_key=lambda x: x.get("speaker_id", ""), + get_intervals=_intervals_from_diar_segments, + make_filename=_make_filename, + make_metadata=_base_metadata, + ) - return extracted, total_dur, speaker_counts, metadata_rows + +# ------------------------------------------------------------------ +# Combo 4: VAD + speaker -- extract each speaker-segment by timestamps +# ------------------------------------------------------------------ + + +def extract_speaker_segments_by_timestamps( + entries: list, output_dir: str, output_format: str, +) -> tuple[int, float, dict[str, int], list[dict]]: + """Extract speaker-segments using original_start_ms / original_end_ms.""" + per_speaker_count: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) + + def _make_filename(name: str, entry: dict, _seg_idx: int) -> str: + speaker_id, speaker_num = _get_speaker_label(entry) + idx = per_speaker_count[name][speaker_id] + per_speaker_count[name][speaker_id] += 1 + return f"{name}_speaker_{speaker_num}_segment_{idx:03d}.{output_format}" + + return _extract_file_segments( + entries, output_dir, output_format, + sort_key=lambda x: (x.get("speaker_id", ""), x.get("original_start_ms", 0)), + get_intervals=_intervals_from_timestamps, + make_filename=_make_filename, + make_metadata=_base_metadata, + ) # ------------------------------------------------------------------ From 0f9319d0ca804273eeb2bb2e1bcee53227cb628a Mon Sep 17 00:00:00 2001 From: shbhawsar Date: Fri, 17 Apr 2026 11:50:20 -0700 Subject: [PATCH 05/11] Fix ruff lint errors in test_timestamp_mapper and extract_segments --- .../stages/audio/postprocessing/test_timestamp_mapper.py | 4 +++- tutorials/audio/readspeech/extract_segments.py | 9 +++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/stages/audio/postprocessing/test_timestamp_mapper.py b/tests/stages/audio/postprocessing/test_timestamp_mapper.py index fa64578dd6..74a077d8a9 100644 --- a/tests/stages/audio/postprocessing/test_timestamp_mapper.py +++ b/tests/stages/audio/postprocessing/test_timestamp_mapper.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import ClassVar + import torch from nemo_curator.stages.audio.postprocessing.timestamp_mapper import ( @@ -32,7 +34,7 @@ def _make_task(data: dict, task_id: str = "test", metadata: dict | None = None) class TestTranslateToOriginal: """Unit tests for the pure _translate_to_original() function.""" - MAPPINGS = [ + MAPPINGS: ClassVar[list[dict]] = [ {"concat_start_ms": 0, "concat_end_ms": 2000, "original_file": "a.wav", "original_start_ms": 5000, "original_end_ms": 7000}, {"concat_start_ms": 2000, "concat_end_ms": 5000, "original_file": "b.wav", "original_start_ms": 0, "original_end_ms": 3000}, {"concat_start_ms": 5000, "concat_end_ms": 8000, "original_file": "c.wav", "original_start_ms": 10000, "original_end_ms": 13000}, diff --git a/tutorials/audio/readspeech/extract_segments.py b/tutorials/audio/readspeech/extract_segments.py index 96c2212d1e..3758f672f7 100755 --- a/tutorials/audio/readspeech/extract_segments.py +++ b/tutorials/audio/readspeech/extract_segments.py @@ -57,8 +57,9 @@ import os import sys from collections import defaultdict +from collections.abc import Callable from pathlib import Path -from typing import Any, Callable +from typing import Any import numpy as np import soundfile as sf @@ -202,7 +203,7 @@ def _get_speaker_label(entry: dict) -> tuple[str, str]: return speaker_id, speaker_num -def _extract_file_segments( +def _extract_file_segments( # noqa: PLR0913 entries: list, output_dir: str, output_format: str, @@ -293,7 +294,7 @@ def _intervals_from_diar_segments(entry: dict) -> list[Interval]: ] -def _base_metadata( +def _base_metadata( # noqa: PLR0913 filename: str, original_file: str, entry: dict, seg_idx: int, start_ms: int, end_ms: int, dur: float, ) -> dict: @@ -326,7 +327,7 @@ def extract_segments_by_timestamps( """Extract segments by original_start_ms / original_end_ms, sorted by start time.""" counter: dict[str, int] = defaultdict(int) - def _make_filename(name: str, entry: dict, _seg_idx: int) -> str: + def _make_filename(name: str, _entry: dict, _seg_idx: int) -> str: idx = counter[name] counter[name] += 1 return f"{name}_segment_{idx:03d}.{output_format}" From 55a389cd674854f0b9bab8cdf840abcfba289728 Mon Sep 17 00:00:00 2001 From: shbhawsar Date: Fri, 17 Apr 2026 12:47:55 -0700 Subject: [PATCH 06/11] Address PR review: SpeakerResult NamedTuple, docs, passthrough_keys - Replace fragile tuple return type in get_speaker_audio_data with SpeakerResult NamedTuple for self-documenting named field access. - Update caller in speaker_separation.py to use result.audio, result.duration, result.diar_segments instead of positional unpacking. - Document _INHERITED_DROP_KEYS explaining why duration/num_samples are dropped from parent tasks during speaker separation. - Add comments in pipeline.py and pipeline.yaml documenting the default passthrough_keys and how to restrict output columns. - Add speaker separation note to README explaining dropped keys. 
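
For illustration, a minimal sketch of the named-field access this refactor enables. SpeakerResult and its fields follow the definition in the diff below; the dictionary values here are made-up stand-ins (audio is typed loosely instead of a real pydub AudioSegment), not actual pipeline output:

    from typing import NamedTuple

    class SpeakerResult(NamedTuple):
        audio: object                             # pydub AudioSegment in the real code
        duration: float                           # speaking time in seconds
        diar_segments: list[tuple[float, float]]  # (start_sec, end_sec) intervals

    # Hypothetical diarization output for two speakers (values invented for the sketch).
    speaker_audio_data = {
        "speaker_0": SpeakerResult(audio=None, duration=4.2, diar_segments=[(0.0, 2.0), (3.0, 5.2)]),
        "speaker_1": SpeakerResult(audio=None, duration=0.4, diar_segments=[(2.1, 2.5)]),
    }

    min_duration = 0.8
    for speaker_id, result in speaker_audio_data.items():
        # Named access replaces positional unpacking of (audio, duration, diar_segments).
        if result.duration < min_duration:
            continue
        print(speaker_id, result.duration, result.diar_segments)

Compared with the old bare tuple, misordering fields now fails loudly at the attribute access instead of silently swapping duration and diar_segments.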
--- .../audio/segmentation/speaker_separation.py | 16 ++++++++++------ .../speaker_separation_module/speaker_sep.py | 13 +++++++++++-- tutorials/audio/readspeech/README.md | 7 +++++++ tutorials/audio/readspeech/pipeline.py | 3 +++ tutorials/audio/readspeech/pipeline.yaml | 7 +++++++ 5 files changed, 38 insertions(+), 8 deletions(-) diff --git a/nemo_curator/stages/audio/segmentation/speaker_separation.py b/nemo_curator/stages/audio/segmentation/speaker_separation.py index 858bce831b..5ffe71d779 100755 --- a/nemo_curator/stages/audio/segmentation/speaker_separation.py +++ b/nemo_curator/stages/audio/segmentation/speaker_separation.py @@ -158,6 +158,10 @@ def _initialize_separator(self) -> None: logger.error(f"Failed to load speaker separator: {e}") raise + # Keys dropped from the parent task when building per-speaker child tasks. + # "audio"/"waveform" are non-serializable blobs replaced by per-speaker audio. + # "duration"/"num_samples" describe the parent file, not the speaker segment; + # each child gets its own duration from the diarization result. _INHERITED_DROP_KEYS = frozenset({"audio", "waveform", "duration", "num_samples"}) def _build_speaker_tasks( @@ -169,19 +173,19 @@ def _build_speaker_tasks( """Build AudioTask list from speaker audio data.""" results: list[AudioTask] = [] num_speakers = len(speaker_audio_data) - for speaker_id, (speaker_audio_pydub, duration, diar_segments) in speaker_audio_data.items(): - if duration < self.min_duration: - logger.debug(f"Skipping {speaker_id}: duration {duration:.2f}s < {self.min_duration}s") + for speaker_id, result in speaker_audio_data.items(): + if result.duration < self.min_duration: + logger.debug(f"Skipping {speaker_id}: duration {result.duration:.2f}s < {self.min_duration}s") continue - spk_waveform, spk_sr = _pydub_to_waveform_sr(speaker_audio_pydub) + spk_waveform, spk_sr = _pydub_to_waveform_sr(result.audio) speaker_data = { **{k: v for k, v in item.items() if k not in self._INHERITED_DROP_KEYS}, "waveform": spk_waveform, "sample_rate": spk_sr, "speaker_id": speaker_id, "num_speakers": num_speakers, - "duration": duration, - "diar_segments": diar_segments, + "duration": result.duration, + "diar_segments": result.diar_segments, } spk_task = AudioTask( data=speaker_data, diff --git a/nemo_curator/stages/audio/segmentation/speaker_separation_module/speaker_sep.py b/nemo_curator/stages/audio/segmentation/speaker_separation_module/speaker_sep.py index 715d6452f8..aa0426abcd 100755 --- a/nemo_curator/stages/audio/segmentation/speaker_separation_module/speaker_sep.py +++ b/nemo_curator/stages/audio/segmentation/speaker_separation_module/speaker_sep.py @@ -14,12 +14,21 @@ import os import tempfile +from typing import NamedTuple import soundfile as sf import torch from loguru import logger from pydub import AudioSegment + +class SpeakerResult(NamedTuple): + """Result for a single speaker from get_speaker_audio_data.""" + + audio: AudioSegment + duration: float + diar_segments: list[tuple[float, float]] + try: from nemo.collections.asr.models import SortformerEncLabelModel except ImportError: @@ -456,7 +465,7 @@ def get_speaker_audio_data( # noqa: PLR0913, C901, PLR0912 exclude_overlaps: bool | None = None, min_duration: float | None = None, buffer_time: float | None = None, - ) -> dict[str, tuple[AudioSegment, float, list[tuple[float, float]]]]: + ) -> dict[str, SpeakerResult]: """ Process an audio file or waveform and return AudioSegment objects for each speaker. 
""" @@ -521,7 +530,7 @@ def get_speaker_audio_data( # noqa: PLR0913, C901, PLR0912 if silent_audio.rms < 1: continue - speaker_audio[speaker] = (silent_audio, total_duration, segments) + speaker_audio[speaker] = SpeakerResult(silent_audio, total_duration, segments) # Free the original audio to release memory before returning del original_audio diff --git a/tutorials/audio/readspeech/README.md b/tutorials/audio/readspeech/README.md index 14aae38d78..90bc9944c1 100644 --- a/tutorials/audio/readspeech/README.md +++ b/tutorials/audio/readspeech/README.md @@ -230,6 +230,13 @@ AudioDataFilterStage(config={ are always blocked, even if added to `passthrough_keys`. A warning is logged if blocked keys are detected in the configuration. +**Speaker separation note**: When speaker separation is enabled, the parent +task's `duration` and `num_samples` fields are dropped before building +per-speaker child tasks, since each speaker segment has its own duration +computed from the diarization result. Only `audio`/`waveform` (non-serializable) +and `duration`/`num_samples` (parent-specific) are dropped; all other fields +are inherited by child tasks. + ### Example outputs **Combo 1** (no VAD, no speaker): diff --git a/tutorials/audio/readspeech/pipeline.py b/tutorials/audio/readspeech/pipeline.py index de25ddaa3a..9117ff20c1 100644 --- a/tutorials/audio/readspeech/pipeline.py +++ b/tutorials/audio/readspeech/pipeline.py @@ -122,6 +122,9 @@ def create_readspeech_pipeline(args: argparse.Namespace) -> Pipeline: "exclude_overlaps": args.speaker_exclude_overlaps, "min_duration": args.speaker_min_duration, }, + # Empty dict uses _DEFAULT_PASSTHROUGH_KEYS (all 13 built-in + # filter/speaker keys). To restrict output columns, set e.g.: + # "passthrough_keys": ["utmos_mos", "sigmos_noise", "sigmos_ovrl"] "timestamp_mapper": {}, } ) diff --git a/tutorials/audio/readspeech/pipeline.yaml b/tutorials/audio/readspeech/pipeline.yaml index bbad9e2ad6..fbc47d02c9 100644 --- a/tutorials/audio/readspeech/pipeline.yaml +++ b/tutorials/audio/readspeech/pipeline.yaml @@ -92,6 +92,13 @@ processors: enable: ${enable_speaker_separation} exclude_overlaps: ${speaker_exclude_overlaps} min_duration: ${speaker_min_duration} + # Empty config uses _DEFAULT_PASSTHROUGH_KEYS which includes: + # speaker_id, num_speakers, speaking_duration, sample_rate, + # utmos_mos, sigmos_noise, sigmos_ovrl, sigmos_sig, sigmos_col, + # sigmos_disc, sigmos_loud, sigmos_reverb, band_prediction. + # To restrict output columns, set passthrough_keys explicitly: + # timestamp_mapper: + # passthrough_keys: [utmos_mos, sigmos_noise, sigmos_ovrl] timestamp_mapper: {} - _target_: nemo_curator.stages.audio.io.convert.AudioToDocumentStage From b5a43d21b9afa4f513b61d21df874b7c4e1812aa Mon Sep 17 00:00:00 2001 From: shbhawsar Date: Fri, 17 Apr 2026 13:16:15 -0700 Subject: [PATCH 07/11] Fix test mocks to use SpeakerResult NamedTuple Update test_speaker_separation.py to construct SpeakerResult instead of plain tuples, matching the named attribute access introduced in speaker_separation.py. 
--- .../audio/segmentation/test_speaker_separation.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/tests/stages/audio/segmentation/test_speaker_separation.py b/tests/stages/audio/segmentation/test_speaker_separation.py index ea070d64b0..43067eb5e2 100644 --- a/tests/stages/audio/segmentation/test_speaker_separation.py +++ b/tests/stages/audio/segmentation/test_speaker_separation.py @@ -19,7 +19,10 @@ from pydub import AudioSegment from nemo_curator.stages.audio.segmentation.speaker_separation import SpeakerSeparationStage -from nemo_curator.stages.audio.segmentation.speaker_separation_module.speaker_sep import SpeakerSeparator +from nemo_curator.stages.audio.segmentation.speaker_separation_module.speaker_sep import ( + SpeakerResult, + SpeakerSeparator, +) from nemo_curator.tasks import AudioTask @@ -43,8 +46,8 @@ def test_process_returns_per_speaker_tasks(self, mock_init: MagicMock) -> None: separator = MagicMock() speaker_data = { - "speaker_0": (_make_audio_segment(3000), 3.0, [(0.0, 3.0)]), - "speaker_1": (_make_audio_segment(4000), 4.0, [(0.0, 4.0)]), + "speaker_0": SpeakerResult(_make_audio_segment(3000), 3.0, [(0.0, 3.0)]), + "speaker_1": SpeakerResult(_make_audio_segment(4000), 4.0, [(0.0, 4.0)]), } separator.get_speaker_audio_data.return_value = speaker_data stage._separator = separator @@ -66,7 +69,7 @@ def test_process_output_keys(self, mock_init: MagicMock) -> None: separator = MagicMock() separator.get_speaker_audio_data.return_value = { - "spk_0": (_make_audio_segment(5000), 5.0, [(0.0, 5.0)]), + "spk_0": SpeakerResult(_make_audio_segment(5000), 5.0, [(0.0, 5.0)]), } stage._separator = separator @@ -86,8 +89,8 @@ def test_min_duration_filters_short_speakers(self, mock_init: MagicMock) -> None separator = MagicMock() separator.get_speaker_audio_data.return_value = { - "speaker_0": (_make_audio_segment(5000), 5.0, [(0.0, 5.0)]), - "speaker_1": (_make_audio_segment(1000), 1.0, [(0.0, 1.0)]), + "speaker_0": SpeakerResult(_make_audio_segment(5000), 5.0, [(0.0, 5.0)]), + "speaker_1": SpeakerResult(_make_audio_segment(1000), 1.0, [(0.0, 1.0)]), } stage._separator = separator From 94480421221c3152b0ab65c3e04439f360d84fca Mon Sep 17 00:00:00 2001 From: shbhawsar Date: Mon, 20 Apr 2026 05:26:20 -0700 Subject: [PATCH 08/11] refactor: move extract_segments logic into SegmentExtractionStage Move reusable segment extraction functions from the tutorial script into the nemo_curator package as a proper ProcessingStage. The tutorial becomes a thin CLI wrapper importing from the package. - Add nemo_curator/stages/audio/io/extract_segments.py with SegmentExtractionStage (ProcessingStage[AudioTask, AudioTask]) - Slim tutorials/audio/readspeech/extract_segments.py to CLI wrapper - Add tests/stages/audio/io/test_extract_segments.py with clean imports --- .../stages/audio/io/extract_segments.py | 565 ++++++++++++++++++ .../stages/audio/io/test_extract_segments.py | 417 +++++++++++++ .../audio/readspeech/extract_segments.py | 459 +------------- 3 files changed, 995 insertions(+), 446 deletions(-) create mode 100644 nemo_curator/stages/audio/io/extract_segments.py create mode 100644 tests/stages/audio/io/test_extract_segments.py diff --git a/nemo_curator/stages/audio/io/extract_segments.py b/nemo_curator/stages/audio/io/extract_segments.py new file mode 100644 index 0000000000..cdf16e0171 --- /dev/null +++ b/nemo_curator/stages/audio/io/extract_segments.py @@ -0,0 +1,565 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Audio segment extraction stage. + +Extracts audio segments from original source files based on manifest +entries produced by NeMo Curator audio pipelines. Auto-detects the +pipeline combo from the manifest schema and applies the appropriate +extraction strategy: + + Combo 2 (no VAD / VAD only): + Extracts each segment by ``original_start_ms`` / ``original_end_ms``. + Output: ``{original_filename}_segment_{NNN}.{format}`` + + Combo 3 (speaker diarization): + Extracts each speaking interval from ``diar_segments`` per speaker. + Output: ``{original_filename}_speaker_{X}_segment_{NNN}.{format}`` + + Combo 4 (VAD + speaker): + Extracts each speaker-segment by timestamps. + Output: ``{original_filename}_speaker_{X}_segment_{NNN}.{format}`` + +Example: + from nemo_curator.stages.audio.io.extract_segments import SegmentExtractionStage + + stage = SegmentExtractionStage( + output_dir="/data/extracted", + output_format="flac", + ) + + # Standalone usage (post-pipeline): + stage.extract_from_manifest("manifest.jsonl") + + # Or as a pipeline stage: + pipeline.add_stage(stage) +""" + +from __future__ import annotations + +import csv +import glob +import json +import os +from collections import defaultdict +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import soundfile as sf +from loguru import logger + +if TYPE_CHECKING: + from collections.abc import Callable + + import numpy as np + +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.resources import Resources +from nemo_curator.tasks import AudioTask + +DEFAULT_OUTPUT_FORMAT = "wav" + +SOUNDFILE_FORMATS = { + "wav": "PCM_16", + "flac": "PCM_16", + "ogg": "VORBIS", +} + +_CSV_STRUCTURAL_KEYS = frozenset( + { + "filename", + "original_file", + "original_start_ms", + "original_end_ms", + "duration_ms", + "start_sec", + "end_sec", + "duration", + "segment_index", + "speaker_id", + "num_speakers", + "speaking_duration", + "diar_segments", + } +) + +Interval = tuple[int, int, float] # (start_ms, end_ms, duration_sec) + + +# ------------------------------------------------------------------ +# Pure helper functions +# ------------------------------------------------------------------ + + +def _extract_scores(entry: dict) -> dict: + """Extract quality/filter score fields from a manifest entry. + + Returns all keys that are not structural CSV columns (timestamps, + duration, speaker info), with float values rounded for readability. + Since TimestampMapper already whitelist-filters the manifest output, + anything remaining is a quality score or user-defined field. 
+ """ + return {k: round(v, 4) if isinstance(v, float) else v for k, v in entry.items() if k not in _CSV_STRUCTURAL_KEYS} + + +def _get_speaker_label(entry: dict) -> tuple[str, str]: + """Return (speaker_id, speaker_num) from a manifest entry.""" + speaker_id = entry.get("speaker_id", "unknown") + speaker_num = speaker_id.replace("speaker_", "") if "speaker_" in speaker_id else speaker_id + return speaker_id, speaker_num + + +def _read_segment(filepath: str, start_ms: int, end_ms: int, sample_rate: int) -> np.ndarray: + """Read a slice of audio from a file.""" + start_sample = int(start_ms * sample_rate / 1000) + end_sample = int(end_ms * sample_rate / 1000) + audio, _ = sf.read(filepath, start=start_sample, stop=end_sample, dtype="float32") + return audio + + +def _intervals_from_timestamps(entry: dict) -> list[Interval]: + start_ms = entry.get("original_start_ms", 0) + end_ms = entry.get("original_end_ms", 0) + dur = entry.get("duration", (end_ms - start_ms) / 1000) + return [(start_ms, end_ms, dur)] + + +def _intervals_from_diar_segments(entry: dict) -> list[Interval]: + diar_segments = entry.get("diar_segments", []) + if not diar_segments: + speaker_id = entry.get("speaker_id", "unknown") + logger.warning(f" {speaker_id}: no diar_segments, skipping") + return [] + return [ + (int(s * 1000), int(e * 1000), e - s) + for s, e in sorted(diar_segments, key=lambda x: x[0]) + ] + + +def _base_metadata( # noqa: PLR0913 + filename: str, original_file: str, entry: dict, + seg_idx: int, start_ms: int, end_ms: int, dur: float, +) -> dict: + row: dict = { + "filename": filename, + "original_file": original_file, + "segment_index": seg_idx, + "start_sec": round(start_ms / 1000, 3), + "end_sec": round(end_ms / 1000, 3), + "duration": round(dur, 3), + } + speaker_id = entry.get("speaker_id") + if speaker_id is not None: + row["speaker_id"] = speaker_id + num_speakers = entry.get("num_speakers") + if num_speakers is not None: + row["num_speakers"] = num_speakers + row.update(_extract_scores(entry)) + return row + + +def detect_combo(entries: list) -> int: + """Detect which pipeline combo produced the manifest. + + Returns 2, 3, or 4. Since TimestampMapper always emits + ``original_start_ms``/``original_end_ms``, combos 1 and 2 are + indistinguishable and both use timestamp-based extraction. 
+ + Returns: + 2: segments by timestamps (combos 1 and 2) + 3: speaker diarization segments + 4: speaker-segments by timestamps + """ + if not entries: + return 2 + + first = entries[0] + has_speaker = "speaker_id" in first + has_diar = "diar_segments" in first + + if has_speaker and has_diar: + return 3 + if has_speaker: + return 4 + return 2 + + +def load_manifest(manifest_path: str) -> list: + """Load a single manifest.jsonl file and return list of entries.""" + entries = [] + with open(manifest_path) as f: + for line_num, raw_line in enumerate(f, 1): + line = raw_line.strip() + if not line: + continue + try: + entries.append(json.loads(line)) + except json.JSONDecodeError as e: + logger.warning(f"Failed to parse line {line_num} in {manifest_path}: {e}") + return entries + + +def load_manifests(input_path: str, output_dir: str) -> list: + """Load entries from a single jsonl file or a directory of jsonl files.""" + if os.path.isfile(input_path): + return load_manifest(input_path) + + if not os.path.isdir(input_path): + logger.error(f"Input path not found: {input_path}") + return [] + + jsonl_files = sorted(glob.glob(os.path.join(input_path, "*.jsonl"))) + if not jsonl_files: + logger.error(f"No .jsonl files found in {input_path}") + return [] + + logger.info(f"Found {len(jsonl_files)} jsonl files in {input_path}") + + all_entries = [] + for jf in jsonl_files: + all_entries.extend(load_manifest(jf)) + + logger.info(f"Combined {len(all_entries)} entries from {len(jsonl_files)} file(s)") + + if all_entries: + os.makedirs(output_dir, exist_ok=True) + combined_path = os.path.join(output_dir, "manifest.jsonl") + with open(combined_path, "w") as f: + f.writelines(json.dumps(e) + "\n" for e in all_entries) + logger.info(f"Saved combined manifest to {combined_path}") + + return all_entries + + +def _write_metadata_csv(output_dir: str, metadata_rows: list[dict]) -> str: + """Write metadata.csv from collected metadata rows.""" + if not metadata_rows: + return "" + + all_keys: list[str] = [] + seen: set[str] = set() + for row in metadata_rows: + for k in row: + if k not in seen: + all_keys.append(k) + seen.add(k) + + csv_path = os.path.join(output_dir, "metadata.csv") + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=all_keys) + writer.writeheader() + writer.writerows(metadata_rows) + + return csv_path + + +# ------------------------------------------------------------------ +# Stage +# ------------------------------------------------------------------ + + +@dataclass +class SegmentExtractionStage(ProcessingStage[AudioTask, AudioTask]): + """Extract audio segments from original files based on manifest entries. + + Receives ``AudioTask`` objects whose ``data`` dicts are manifest + entries (produced by ``TimestampMapperStage``). For each entry the + stage reads the audio slice from the original file and writes it as + a standalone segment file. + + The pipeline combo is auto-detected from the first entry in each + batch. Entries are grouped by ``original_file`` so each source is + opened only once per batch. + + This is an IO stage: ``process()`` raises ``NotImplementedError`` + and all work is done in ``process_batch()``, following the same + pattern as ``AudioToDocumentStage`` and ``ALMManifestWriterStage``. + + Args: + output_dir: Directory where extracted segment files are written. + output_format: Audio format — ``wav``, ``flac``, or ``ogg``. 
+ """ + + name: str = "SegmentExtraction" + output_dir: str = "" + output_format: str = DEFAULT_OUTPUT_FORMAT + batch_size: int = 64 + resources: Resources = field(default_factory=lambda: Resources(cpus=1.0)) + + def __post_init__(self) -> None: + super().__init__() + if not self.output_dir: + msg = "output_dir is required for SegmentExtractionStage" + raise ValueError(msg) + if self.output_format not in SOUNDFILE_FORMATS: + msg = f"output_format must be one of {list(SOUNDFILE_FORMATS)}, got {self.output_format!r}" + raise ValueError(msg) + + def inputs(self) -> tuple[list[str], list[str]]: + return [], ["original_file"] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], ["extracted_path"] + + def process(self, task: AudioTask) -> AudioTask: + msg = "SegmentExtractionStage only supports process_batch" + raise NotImplementedError(msg) + + def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: + if len(tasks) == 0: + return [] + + os.makedirs(self.output_dir, exist_ok=True) + + entries = [t.data for t in tasks] + combo = detect_combo(entries) + + extractors = { + 2: self._extract_by_timestamps, + 3: self._extract_speaker_diar, + 4: self._extract_speaker_timestamps, + } + extracted, total_dur, speaker_counts, metadata_rows = extractors[combo](entries) + + _write_metadata_csv(self.output_dir, metadata_rows) + + logger.info( + f"[{self.name}] Extracted {extracted} segments " + f"({total_dur:.1f}s) from {len(tasks)} entries" + ) + if speaker_counts: + for speaker, count in sorted(speaker_counts.items()): + logger.debug(f" {speaker}: {count} segments") + + return tasks + + # ------------------------------------------------------------------ + # Combo extractors (instance methods using self.output_dir/format) + # ------------------------------------------------------------------ + + def _extract_by_timestamps( + self, entries: list[dict], + ) -> tuple[int, float, dict[str, int], list[dict]]: + """Combo 2: extract by original_start_ms / original_end_ms.""" + counter: dict[str, int] = defaultdict(int) + + def _make_filename(name: str, _entry: dict, _seg_idx: int) -> str: + idx = counter[name] + counter[name] += 1 + return f"{name}_segment_{idx:03d}.{self.output_format}" + + return self._extract_file_segments( + entries, + sort_key=lambda x: x.get("original_start_ms", 0), + get_intervals=_intervals_from_timestamps, + make_filename=_make_filename, + ) + + def _extract_speaker_diar( + self, entries: list[dict], + ) -> tuple[int, float, dict[str, int], list[dict]]: + """Combo 3: extract each diar_segment per speaker.""" + + def _make_filename(name: str, entry: dict, seg_idx: int) -> str: + _, speaker_num = _get_speaker_label(entry) + return f"{name}_speaker_{speaker_num}_segment_{seg_idx:03d}.{self.output_format}" + + return self._extract_file_segments( + entries, + sort_key=lambda x: x.get("speaker_id", ""), + get_intervals=_intervals_from_diar_segments, + make_filename=_make_filename, + ) + + def _extract_speaker_timestamps( + self, entries: list[dict], + ) -> tuple[int, float, dict[str, int], list[dict]]: + """Combo 4: extract speaker-segments by timestamps.""" + per_speaker_count: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) + + def _make_filename(name: str, entry: dict, _seg_idx: int) -> str: + speaker_id, speaker_num = _get_speaker_label(entry) + idx = per_speaker_count[name][speaker_id] + per_speaker_count[name][speaker_id] += 1 + return f"{name}_speaker_{speaker_num}_segment_{idx:03d}.{self.output_format}" + + return self._extract_file_segments( 
+ entries, + sort_key=lambda x: (x.get("speaker_id", ""), x.get("original_start_ms", 0)), + get_intervals=_intervals_from_timestamps, + make_filename=_make_filename, + ) + + # ------------------------------------------------------------------ + # Shared extraction engine + # ------------------------------------------------------------------ + + def _extract_file_segments( + self, + entries: list[dict], + *, + sort_key: Callable[[dict], Any], + get_intervals: Callable[[dict], list[Interval]], + make_filename: Callable[[str, dict, int], str], + ) -> tuple[int, float, dict[str, int], list[dict]]: + """Group-by-file -> read -> write -> metadata loop.""" + by_file: dict[str, list] = defaultdict(list) + for entry in entries: + by_file[entry.get("original_file", "")].append(entry) + + extracted = 0 + total_dur = 0.0 + speaker_counts: dict[str, int] = defaultdict(int) + metadata_rows: list[dict] = [] + + for original_file, file_entries in by_file.items(): + if not os.path.exists(original_file): + logger.error(f"Original file not found: {original_file}") + continue + + info = sf.info(original_file) + original_name = Path(original_file).stem + file_entries.sort(key=sort_key) + logger.info(f"\nProcessing: {original_name} ({len(file_entries)} entries)") + + for entry in file_entries: + intervals = get_intervals(entry) + for seg_idx, (start_ms, end_ms, dur) in enumerate(intervals): + out_filename = make_filename(original_name, entry, seg_idx) + output_path = os.path.join(self.output_dir, out_filename) + + try: + audio = _read_segment(original_file, start_ms, end_ms, info.samplerate) + sf.write(output_path, audio, info.samplerate, subtype=SOUNDFILE_FORMATS[self.output_format]) + extracted += 1 + total_dur += dur + + speaker_id = entry.get("speaker_id") + if speaker_id: + speaker_counts[speaker_id] += 1 + + metadata_rows.append( + _base_metadata(out_filename, original_file, entry, seg_idx, start_ms, end_ms, dur) + ) + logger.debug(f" {out_filename} ({start_ms}-{end_ms}ms, {dur:.2f}s)") + except Exception as e: # noqa: BLE001 + logger.error(f" Failed to extract {out_filename}: {e}") + + return extracted, total_dur, speaker_counts, metadata_rows + + # ------------------------------------------------------------------ + # Standalone convenience methods (post-pipeline usage) + # ------------------------------------------------------------------ + + def extract_from_manifest(self, input_path: str) -> None: + """Load a manifest file (or directory of JSONL files) and extract all segments. + + This is a convenience method for standalone usage outside + of a pipeline. It handles manifest loading, combo detection, + CSV metadata, and summary JSON — equivalent to the old + ``extract_segments()`` function. 
+ """ + os.makedirs(self.output_dir, exist_ok=True) + + logger.info(f"Loading manifest: {input_path}") + entries = load_manifests(input_path, self.output_dir) + logger.info(f"Found {len(entries)} entries total") + + if not entries: + logger.error("No entries found in manifest") + return + + combo = detect_combo(entries) + combo_names = { + 2: "Segments by timestamps", + 3: "Speaker diarization segments", + 4: "Speaker-segments by timestamps", + } + logger.info(f"Detected: {combo_names[combo]}") + + extractors = { + 2: self._extract_by_timestamps, + 3: self._extract_speaker_diar, + 4: self._extract_speaker_timestamps, + } + total_extracted, total_dur, speaker_counts, metadata_rows = extractors[combo](entries) + + csv_path = _write_metadata_csv(self.output_dir, metadata_rows) + + summary = { + "manifest_path": input_path, + "output_dir": self.output_dir, + "total_segments": total_extracted, + "total_duration_sec": round(total_dur, 2), + "output_format": self.output_format, + } + if speaker_counts: + summary["segments_by_speaker"] = dict(speaker_counts) + + summary_path = os.path.join(self.output_dir, "extraction_summary.json") + with open(summary_path, "w") as f: + json.dump(summary, f, indent=2) + + logger.info(f"\n{'=' * 60}") + logger.info("EXTRACTION COMPLETE") + logger.info(f"{'=' * 60}") + logger.info(f" Combo: {combo_names[combo]}") + logger.info(f" Total segments: {total_extracted}") + logger.info(f" Total duration: {total_dur:.2f}s ({total_dur / 60:.1f} min)") + logger.info(f" Output: {self.output_dir}") + logger.info(f" Format: {self.output_format}") + if speaker_counts: + logger.info(" Segments by speaker:") + for speaker, count in sorted(speaker_counts.items()): + logger.info(f" {speaker}: {count} segments") + if csv_path: + logger.info(f" Metadata CSV: {csv_path}") + logger.info(f" Summary: {summary_path}") + + +# ------------------------------------------------------------------ +# Backward-compatible free functions (delegate to stage) +# ------------------------------------------------------------------ + + +def extract_segments_by_timestamps( + entries: list, output_dir: str, output_format: str, +) -> tuple[int, float, dict[str, int], list[dict]]: + """Extract segments by original_start_ms / original_end_ms, sorted by start time.""" + stage = SegmentExtractionStage(output_dir=output_dir, output_format=output_format) + return stage._extract_by_timestamps(entries) + + +def extract_speaker_diar_segments( + entries: list, output_dir: str, output_format: str, +) -> tuple[int, float, dict[str, int], list[dict]]: + """Extract individual speaking intervals from diar_segments per speaker.""" + stage = SegmentExtractionStage(output_dir=output_dir, output_format=output_format) + return stage._extract_speaker_diar(entries) + + +def extract_speaker_segments_by_timestamps( + entries: list, output_dir: str, output_format: str, +) -> tuple[int, float, dict[str, int], list[dict]]: + """Extract speaker-segments using original_start_ms / original_end_ms.""" + stage = SegmentExtractionStage(output_dir=output_dir, output_format=output_format) + return stage._extract_speaker_timestamps(entries) + + +def extract_segments(input_path: str, output_dir: str, output_format: str = DEFAULT_OUTPUT_FORMAT) -> None: + """Extract segments from original audio files based on manifest.""" + stage = SegmentExtractionStage(output_dir=output_dir, output_format=output_format) + stage.extract_from_manifest(input_path) diff --git a/tests/stages/audio/io/test_extract_segments.py 
b/tests/stages/audio/io/test_extract_segments.py new file mode 100644 index 0000000000..333197cd1a --- /dev/null +++ b/tests/stages/audio/io/test_extract_segments.py @@ -0,0 +1,417 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for nemo_curator.stages.audio.io.extract_segments.""" + +import csv +import json +import os +from pathlib import Path + +import numpy as np +import pytest +import soundfile as sf + +from nemo_curator.stages.audio.io.extract_segments import ( + SegmentExtractionStage, + _base_metadata, + _extract_scores, + _get_speaker_label, + _intervals_from_diar_segments, + _intervals_from_timestamps, + _read_segment, + _write_metadata_csv, + detect_combo, + load_manifest, + load_manifests, +) +from nemo_curator.tasks import AudioTask + +SAMPLE_RATE = 16000 +DURATION_SEC = 5.0 +NUM_SAMPLES = int(SAMPLE_RATE * DURATION_SEC) + + +# ------------------------------------------------------------------ +# Fixtures +# ------------------------------------------------------------------ + + +@pytest.fixture +def wav_dir(tmp_path: Path) -> Path: + """Create a temp directory with two 5-second mono WAV files.""" + for name in ("file_a", "file_b"): + audio = np.random.default_rng(42).uniform(-0.5, 0.5, NUM_SAMPLES).astype(np.float32) + sf.write(str(tmp_path / f"{name}.wav"), audio, SAMPLE_RATE) + return tmp_path + + +def _wav_path(wav_dir: Path, name: str = "file_a") -> str: + return str(wav_dir / f"{name}.wav") + + +def _write_manifest(path: Path, entries: list[dict]) -> str: + manifest = str(path / "manifest.jsonl") + with open(manifest, "w") as f: + f.writelines(json.dumps(e) + "\n" for e in entries) + return manifest + + +# ------------------------------------------------------------------ +# Pure helper functions +# ------------------------------------------------------------------ + + +class TestDetectCombo: + def test_empty_entries(self) -> None: + assert detect_combo([]) == 2 + + def test_no_speaker_no_diar(self) -> None: + assert detect_combo([{"original_file": "/a.wav", "original_start_ms": 0}]) == 2 + + def test_speaker_and_diar(self) -> None: + assert detect_combo([{"speaker_id": "speaker_0", "diar_segments": [[1.0, 2.0]]}]) == 3 + + def test_speaker_no_diar(self) -> None: + assert detect_combo([{"speaker_id": "speaker_0", "original_start_ms": 0}]) == 4 + + +class TestExtractScores: + def test_filters_structural_keys_and_rounds(self) -> None: + entry = { + "original_file": "/a.wav", + "duration": 5.0, + "speaker_id": "speaker_0", + "utmos_mos": 4.12345, + "custom_field": "hello", + } + scores = _extract_scores(entry) + assert "original_file" not in scores + assert "duration" not in scores + assert "speaker_id" not in scores + assert scores["utmos_mos"] == 4.1235 + assert scores["custom_field"] == "hello" + + def test_empty_entry(self) -> None: + assert _extract_scores({}) == {} + + +class TestGetSpeakerLabel: + def test_standard_format(self) -> None: + assert _get_speaker_label({"speaker_id": "speaker_2"}) == 
("speaker_2", "2") + + def test_missing_speaker(self) -> None: + assert _get_speaker_label({}) == ("unknown", "unknown") + + def test_non_standard_id(self) -> None: + assert _get_speaker_label({"speaker_id": "alice"}) == ("alice", "alice") + + +class TestIntervalsFromTimestamps: + def test_basic(self) -> None: + assert _intervals_from_timestamps({"original_start_ms": 1000, "original_end_ms": 3000, "duration": 2.0}) == [(1000, 3000, 2.0)] + + def test_computed_duration(self) -> None: + assert _intervals_from_timestamps({"original_start_ms": 500, "original_end_ms": 2500}) == [(500, 2500, 2.0)] + + def test_missing_keys_default_zero(self) -> None: + assert _intervals_from_timestamps({}) == [(0, 0, 0.0)] + + +class TestIntervalsFromDiarSegments: + def test_basic(self) -> None: + result = _intervals_from_diar_segments({"diar_segments": [[1.0, 2.5], [3.0, 4.0]]}) + assert result == [(1000, 2500, 1.5), (3000, 4000, 1.0)] + + def test_empty_diar(self) -> None: + assert _intervals_from_diar_segments({"speaker_id": "speaker_0"}) == [] + + def test_sorted_output(self) -> None: + result = _intervals_from_diar_segments({"diar_segments": [[3.0, 4.0], [1.0, 2.0]]}) + assert result[0][0] < result[1][0] + + +class TestReadSegment: + def test_reads_correct_slice(self, wav_dir: Path) -> None: + filepath = _wav_path(wav_dir) + original, sr = sf.read(filepath) + expected = original[int(1.0 * sr):int(2.0 * sr)] + + result = _read_segment(filepath, 1000, 2000, sr) + np.testing.assert_array_almost_equal(result, expected, decimal=4) + + def test_full_file(self, wav_dir: Path) -> None: + filepath = _wav_path(wav_dir) + original, sr = sf.read(filepath) + + result = _read_segment(filepath, 0, 5000, sr) + assert len(result) == len(original) + + def test_zero_duration(self, wav_dir: Path) -> None: + result = _read_segment(_wav_path(wav_dir), 1000, 1000, SAMPLE_RATE) + assert len(result) == 0 + + +class TestBaseMetadata: + def test_without_speaker(self) -> None: + entry = {"original_start_ms": 0, "original_end_ms": 2000, "duration": 2.0, "utmos_mos": 4.1} + row = _base_metadata("out.wav", "/a.wav", entry, 0, 0, 2000, 2.0) + assert row["filename"] == "out.wav" + assert row["start_sec"] == 0.0 + assert row["end_sec"] == 2.0 + assert row["utmos_mos"] == 4.1 + assert "speaker_id" not in row + + def test_with_speaker(self) -> None: + entry = {"speaker_id": "speaker_0", "num_speakers": 3, "original_start_ms": 0, "original_end_ms": 1000} + row = _base_metadata("out.wav", "/a.wav", entry, 0, 0, 1000, 1.0) + assert row["speaker_id"] == "speaker_0" + assert row["num_speakers"] == 3 + + +# ------------------------------------------------------------------ +# Manifest loading +# ------------------------------------------------------------------ + + +class TestLoadManifest: + def test_load_valid(self, tmp_path: Path) -> None: + path = _write_manifest(tmp_path, [{"a": 1}, {"b": 2}]) + entries = load_manifest(path) + assert len(entries) == 2 + assert entries[0] == {"a": 1} + + def test_skip_empty_and_malformed_lines(self, tmp_path: Path) -> None: + p = str(tmp_path / "m.jsonl") + with open(p, "w") as f: + f.write('{"a":1}\n\nNOT JSON\n{"b":2}\n') + assert len(load_manifest(p)) == 2 + + def test_empty_file(self, tmp_path: Path) -> None: + p = str(tmp_path / "empty.jsonl") + with open(p, "w") as f: + f.write("") + assert load_manifest(p) == [] + + +class TestLoadManifests: + def test_directory_of_jsonl(self, tmp_path: Path) -> None: + manifest_dir = tmp_path / "manifests" + manifest_dir.mkdir() + for i in range(3): + with 
open(str(manifest_dir / f"part_{i}.jsonl"), "w") as f: + f.write(json.dumps({"idx": i}) + "\n") + out_dir = str(tmp_path / "out") + entries = load_manifests(str(manifest_dir), out_dir) + assert len(entries) == 3 + assert os.path.exists(os.path.join(out_dir, "manifest.jsonl")) + + def test_nonexistent_path(self, tmp_path: Path) -> None: + assert load_manifests(str(tmp_path / "nope"), str(tmp_path / "out")) == [] + + +# ------------------------------------------------------------------ +# CSV metadata output +# ------------------------------------------------------------------ + + +class TestWriteMetadataCsv: + def test_writes_csv(self, tmp_path: Path) -> None: + rows = [ + {"filename": "a.wav", "duration": 1.0, "utmos_mos": 4.2}, + {"filename": "b.wav", "duration": 2.0, "sigmos_ovrl": 3.5}, + ] + csv_path = _write_metadata_csv(str(tmp_path), rows) + with open(csv_path) as f: + read_rows = list(csv.DictReader(f)) + assert len(read_rows) == 2 + + def test_empty_rows_no_file(self, tmp_path: Path) -> None: + assert _write_metadata_csv(str(tmp_path), []) == "" + + +# ------------------------------------------------------------------ +# SegmentExtractionStage — init & interface +# ------------------------------------------------------------------ + + +class TestSegmentExtractionStageInit: + def test_valid_construction(self, tmp_path: Path) -> None: + stage = SegmentExtractionStage(output_dir=str(tmp_path), output_format="flac") + assert stage.name == "SegmentExtraction" + assert stage.output_format == "flac" + assert stage.batch_size == 64 + + def test_missing_output_dir_raises(self) -> None: + with pytest.raises(ValueError, match="output_dir is required"): + SegmentExtractionStage(output_dir="") + + def test_invalid_format_raises(self, tmp_path: Path) -> None: + with pytest.raises(ValueError, match="output_format must be one of"): + SegmentExtractionStage(output_dir=str(tmp_path), output_format="mp3") + + def test_process_raises_not_implemented(self, tmp_path: Path) -> None: + stage = SegmentExtractionStage(output_dir=str(tmp_path)) + task = AudioTask(data={"original_file": "/a.wav"}, task_id="t", dataset_name="d") + with pytest.raises(NotImplementedError): + stage.process(task) + + def test_inputs_outputs(self, tmp_path: Path) -> None: + stage = SegmentExtractionStage(output_dir=str(tmp_path)) + assert stage.inputs() == ([], ["original_file"]) + assert stage.outputs() == ([], ["extracted_path"]) + + +# ------------------------------------------------------------------ +# SegmentExtractionStage — process_batch +# ------------------------------------------------------------------ + + +class TestSegmentExtractionStageProcessBatch: + def test_empty_batch(self, tmp_path: Path) -> None: + stage = SegmentExtractionStage(output_dir=str(tmp_path / "out")) + assert stage.process_batch([]) == [] + + def test_combo2_timestamps(self, wav_dir: Path, tmp_path: Path) -> None: + out_dir = str(tmp_path / "extracted") + stage = SegmentExtractionStage(output_dir=out_dir) + tasks = [ + AudioTask(data={"original_file": _wav_path(wav_dir), "original_start_ms": 0, "original_end_ms": 2000, "duration": 2.0}, task_id="t1", dataset_name="test"), + AudioTask(data={"original_file": _wav_path(wav_dir), "original_start_ms": 2500, "original_end_ms": 4500, "duration": 2.0}, task_id="t2", dataset_name="test"), + ] + result = stage.process_batch(tasks) + assert len(result) == 2 + assert os.path.exists(os.path.join(out_dir, "file_a_segment_000.wav")) + assert os.path.exists(os.path.join(out_dir, "file_a_segment_001.wav")) + + def 
test_combo3_diar_segments(self, wav_dir: Path, tmp_path: Path) -> None: + out_dir = str(tmp_path / "extracted") + stage = SegmentExtractionStage(output_dir=out_dir) + tasks = [ + AudioTask( + data={"original_file": _wav_path(wav_dir), "speaker_id": "speaker_0", "num_speakers": 2, "diar_segments": [[0.5, 1.5], [2.0, 3.0]]}, + task_id="t1", dataset_name="test", + ), + ] + result = stage.process_batch(tasks) + assert len(result) == 1 + assert os.path.exists(os.path.join(out_dir, "file_a_speaker_0_segment_000.wav")) + assert os.path.exists(os.path.join(out_dir, "file_a_speaker_0_segment_001.wav")) + + def test_combo4_speaker_timestamps(self, wav_dir: Path, tmp_path: Path) -> None: + out_dir = str(tmp_path / "extracted") + stage = SegmentExtractionStage(output_dir=out_dir) + tasks = [ + AudioTask(data={"original_file": _wav_path(wav_dir), "speaker_id": "speaker_0", "original_start_ms": 0, "original_end_ms": 1000, "duration": 1.0}, task_id="t1", dataset_name="test"), + AudioTask(data={"original_file": _wav_path(wav_dir), "speaker_id": "speaker_1", "original_start_ms": 1500, "original_end_ms": 2500, "duration": 1.0}, task_id="t2", dataset_name="test"), + ] + result = stage.process_batch(tasks) + assert len(result) == 2 + assert os.path.exists(os.path.join(out_dir, "file_a_speaker_0_segment_000.wav")) + assert os.path.exists(os.path.join(out_dir, "file_a_speaker_1_segment_000.wav")) + + def test_flac_format(self, wav_dir: Path, tmp_path: Path) -> None: + out_dir = str(tmp_path / "extracted") + stage = SegmentExtractionStage(output_dir=out_dir, output_format="flac") + tasks = [ + AudioTask(data={"original_file": _wav_path(wav_dir), "original_start_ms": 0, "original_end_ms": 1000, "duration": 1.0}, task_id="t1", dataset_name="test"), + ] + stage.process_batch(tasks) + assert os.path.exists(os.path.join(out_dir, "file_a_segment_000.flac")) + + def test_missing_original_file_skipped(self, tmp_path: Path) -> None: + out_dir = str(tmp_path / "extracted") + stage = SegmentExtractionStage(output_dir=out_dir) + tasks = [ + AudioTask(data={"original_file": "/nonexistent/audio.wav", "original_start_ms": 0, "original_end_ms": 1000, "duration": 1.0}, task_id="t1", dataset_name="test"), + ] + result = stage.process_batch(tasks) + assert len(result) == 1 + assert not os.path.exists(os.path.join(out_dir, "metadata.csv")) + + def test_metadata_csv_written(self, wav_dir: Path, tmp_path: Path) -> None: + out_dir = str(tmp_path / "extracted") + stage = SegmentExtractionStage(output_dir=out_dir) + tasks = [ + AudioTask(data={"original_file": _wav_path(wav_dir), "original_start_ms": 0, "original_end_ms": 1000, "duration": 1.0, "utmos_mos": 4.2}, task_id="t1", dataset_name="test"), + ] + stage.process_batch(tasks) + with open(os.path.join(out_dir, "metadata.csv")) as f: + rows = list(csv.DictReader(f)) + assert len(rows) == 1 + assert float(rows[0]["utmos_mos"]) == 4.2 + + def test_audio_content_matches_source(self, wav_dir: Path, tmp_path: Path) -> None: + original, sr = sf.read(_wav_path(wav_dir)) + expected_slice = original[int(1.0 * sr):int(2.0 * sr)] + + out_dir = str(tmp_path / "extracted") + stage = SegmentExtractionStage(output_dir=out_dir) + tasks = [ + AudioTask(data={"original_file": _wav_path(wav_dir), "original_start_ms": 1000, "original_end_ms": 2000, "duration": 1.0}, task_id="t1", dataset_name="test"), + ] + stage.process_batch(tasks) + + extracted, _ = sf.read(os.path.join(out_dir, "file_a_segment_000.wav")) + np.testing.assert_array_almost_equal(extracted, expected_slice, decimal=4) + + +# 
------------------------------------------------------------------ +# SegmentExtractionStage — extract_from_manifest +# ------------------------------------------------------------------ + + +class TestExtractFromManifest: + def test_end_to_end(self, wav_dir: Path, tmp_path: Path) -> None: + entries = [ + {"original_file": _wav_path(wav_dir), "original_start_ms": 0, "original_end_ms": 2000, "duration": 2.0, "utmos_mos": 4.0}, + {"original_file": _wav_path(wav_dir), "original_start_ms": 2500, "original_end_ms": 4500, "duration": 2.0}, + ] + manifest_path = _write_manifest(tmp_path, entries) + out_dir = str(tmp_path / "output") + + stage = SegmentExtractionStage(output_dir=out_dir) + stage.extract_from_manifest(manifest_path) + + assert os.path.exists(os.path.join(out_dir, "file_a_segment_000.wav")) + assert os.path.exists(os.path.join(out_dir, "file_a_segment_001.wav")) + assert os.path.exists(os.path.join(out_dir, "metadata.csv")) + assert os.path.exists(os.path.join(out_dir, "extraction_summary.json")) + + with open(os.path.join(out_dir, "extraction_summary.json")) as f: + summary = json.load(f) + assert summary["total_segments"] == 2 + assert abs(summary["total_duration_sec"] - 4.0) < 0.01 + + def test_empty_manifest(self, tmp_path: Path) -> None: + manifest_path = _write_manifest(tmp_path, []) + out_dir = str(tmp_path / "output") + stage = SegmentExtractionStage(output_dir=out_dir) + stage.extract_from_manifest(manifest_path) + assert not os.path.exists(os.path.join(out_dir, "extraction_summary.json")) + + def test_directory_input(self, wav_dir: Path, tmp_path: Path) -> None: + manifest_dir = tmp_path / "manifests" + manifest_dir.mkdir() + for i, start in enumerate([0, 2000]): + with open(str(manifest_dir / f"part_{i}.jsonl"), "w") as f: + f.write(json.dumps({"original_file": _wav_path(wav_dir), "original_start_ms": start, "original_end_ms": start + 1000, "duration": 1.0}) + "\n") + + out_dir = str(tmp_path / "output") + stage = SegmentExtractionStage(output_dir=out_dir) + stage.extract_from_manifest(str(manifest_dir)) + + assert os.path.exists(os.path.join(out_dir, "manifest.jsonl")) + with open(os.path.join(out_dir, "extraction_summary.json")) as f: + assert json.load(f)["total_segments"] == 2 diff --git a/tutorials/audio/readspeech/extract_segments.py b/tutorials/audio/readspeech/extract_segments.py index 3758f672f7..72d024b202 100755 --- a/tutorials/audio/readspeech/extract_segments.py +++ b/tutorials/audio/readspeech/extract_segments.py @@ -13,470 +13,36 @@ # limitations under the License. """ -Segment Extraction Script +Segment Extraction CLI — thin wrapper around the package implementation. -Reads manifest jsonl file(s) and extracts audio segments from original files. +Reads manifest JSONL file(s) and extracts audio segments from original files. Automatically detects the pipeline combo from the manifest schema and applies -the appropriate extraction strategy: +the appropriate extraction strategy. - Combo 1 (no VAD, no speaker): - Extracts the full file as a single segment (start=0, end=file duration). - Output: {original_filename}_segment_000.{format} - - Combo 2 (VAD only): - Extracts each VAD speech segment by original_start_ms / original_end_ms. - Output: {original_filename}_segment_{NNN}.{format} - Segments are numbered in ascending order of start time. - - Combo 3 (speaker only): - Extracts each speaking interval from diar_segments per speaker. - Output: {original_filename}_speaker_{X}_segment_{NNN}.{format} - Segments are numbered per speaker in ascending order. 
- - Combo 4 (VAD + speaker): - Extracts each speaker-segment by original_start_ms / original_end_ms. - Output: {original_filename}_speaker_{X}_segment_{NNN}.{format} - Segments are numbered per speaker in ascending order of start time. - -Input can be: - - A single manifest.jsonl file - - A directory containing multiple .jsonl files - -Supports configurable output format: wav, flac, ogg (via soundfile). +Each segment is saved with naming convention: + With speaker separation: {original_filename}_speaker_{x}_segment_{y}.{format} + Without speaker separation: {original_filename}_segment_{y}.{format} Usage: python extract_segments.py --manifest manifest.jsonl --output-dir extracted/ python extract_segments.py --manifest /path/to/result_dir/ --output-dir out/ python extract_segments.py --manifest result_dir/ --output-dir out/ --output-format flac + +See ``nemo_curator.stages.audio.io.extract_segments`` for the full API. """ import argparse -import csv -import glob -import json import os import sys -from collections import defaultdict -from collections.abc import Callable -from pathlib import Path -from typing import Any -import numpy as np -import soundfile as sf from loguru import logger -DEFAULT_OUTPUT_FORMAT = "wav" - -SOUNDFILE_FORMATS = { - "wav": "PCM_16", - "flac": "PCM_16", - "ogg": "VORBIS", -} - -_CSV_STRUCTURAL_KEYS = frozenset( - { - "filename", - "original_file", - "original_start_ms", - "original_end_ms", - "duration_ms", - "start_sec", - "end_sec", - "duration", - "segment_index", - "speaker_id", - "num_speakers", - "speaking_duration", - "diar_segments", - } +from nemo_curator.stages.audio.io.extract_segments import ( + DEFAULT_OUTPUT_FORMAT, + SegmentExtractionStage, ) -def _extract_scores(entry: dict) -> dict: - """Extract quality/filter score fields from a manifest entry. - - Returns all keys that are not structural CSV columns (timestamps, - duration, speaker info), with float values rounded for readability. - Since TimestampMapper already whitelist-filters the manifest output, - anything remaining is a quality score or user-defined field. 
- """ - return {k: round(v, 4) if isinstance(v, float) else v for k, v in entry.items() if k not in _CSV_STRUCTURAL_KEYS} - - -def load_manifest(manifest_path: str) -> list: - """Load a single manifest.jsonl file and return list of entries.""" - entries = [] - with open(manifest_path) as f: - for line_num, raw_line in enumerate(f, 1): - line = raw_line.strip() - if not line: - continue - try: - entries.append(json.loads(line)) - except json.JSONDecodeError as e: - logger.warning(f"Failed to parse line {line_num} in {manifest_path}: {e}") - return entries - - -def load_manifests(input_path: str, output_dir: str) -> list: - """Load entries from a single jsonl file or a directory of jsonl files.""" - if os.path.isfile(input_path): - return load_manifest(input_path) - - if not os.path.isdir(input_path): - logger.error(f"Input path not found: {input_path}") - return [] - - jsonl_files = sorted(glob.glob(os.path.join(input_path, "*.jsonl"))) - if not jsonl_files: - logger.error(f"No .jsonl files found in {input_path}") - return [] - - logger.info(f"Found {len(jsonl_files)} jsonl files in {input_path}") - - all_entries = [] - for jf in jsonl_files: - all_entries.extend(load_manifest(jf)) - - logger.info(f"Combined {len(all_entries)} entries from {len(jsonl_files)} file(s)") - - if all_entries: - os.makedirs(output_dir, exist_ok=True) - combined_path = os.path.join(output_dir, "manifest.jsonl") - with open(combined_path, "w") as f: - f.writelines(json.dumps(e) + "\n" for e in all_entries) - logger.info(f"Saved combined manifest to {combined_path}") - - return all_entries - - -def detect_combo(entries: list) -> int: - """Detect which pipeline combo produced the manifest. - - Returns 2, 3, or 4. Since TimestampMapper always emits - ``original_start_ms``/``original_end_ms``, combos 1 and 2 are - indistinguishable and both use timestamp-based extraction. 
- - Returns: - 2: segments by timestamps (combos 1 and 2) - 3: speaker diarization segments - 4: speaker-segments by timestamps - """ - if not entries: - return 2 - - first = entries[0] - has_speaker = "speaker_id" in first - has_diar = "diar_segments" in first - - if has_speaker and has_diar: - return 3 - if has_speaker: - return 4 - return 2 - - -def _write_segment(output_path: str, audio: np.ndarray, sample_rate: int, output_format: str) -> None: - """Write a single audio segment to disk.""" - sf.write(output_path, audio, sample_rate, subtype=SOUNDFILE_FORMATS[output_format]) - - -def _read_segment(filepath: str, start_ms: int, end_ms: int, sample_rate: int) -> np.ndarray: - """Read a slice of audio from a file.""" - start_sample = int(start_ms * sample_rate / 1000) - end_sample = int(end_ms * sample_rate / 1000) - audio, _ = sf.read(filepath, start=start_sample, stop=end_sample, dtype="float32") - return audio - - -# ------------------------------------------------------------------ -# Shared extraction engine -# ------------------------------------------------------------------ - -Interval = tuple[int, int, float] # (start_ms, end_ms, duration_sec) - - -def _get_speaker_label(entry: dict) -> tuple[str, str]: - """Return (speaker_id, speaker_num) from a manifest entry.""" - speaker_id = entry.get("speaker_id", "unknown") - speaker_num = speaker_id.replace("speaker_", "") if "speaker_" in speaker_id else speaker_id - return speaker_id, speaker_num - - -def _extract_file_segments( # noqa: PLR0913 - entries: list, - output_dir: str, - output_format: str, - *, - sort_key: Callable[[dict], Any], - get_intervals: Callable[[dict], list[Interval]], - make_filename: Callable[[str, dict, int], str], - make_metadata: Callable[[str, str, dict, int, int, int, float], dict], -) -> tuple[int, float, dict[str, int], list[dict]]: - """Shared group-by-file -> read -> write -> metadata loop. - - Args: - entries: Manifest entries to extract. - output_dir: Where to write extracted audio files. - output_format: Audio format (wav, flac, ogg). - sort_key: How to sort entries within each file group. - get_intervals: Given an entry, return a list of (start_ms, end_ms, dur_sec). - make_filename: Given (original_name, entry, segment_index), return filename. - make_metadata: Given (filename, original_file, entry, seg_idx, start_ms, - end_ms, dur), return the metadata dict for this segment. 
- """ - by_file: dict[str, list] = defaultdict(list) - for entry in entries: - by_file[entry.get("original_file", "")].append(entry) - - extracted = 0 - total_dur = 0.0 - speaker_counts: dict[str, int] = defaultdict(int) - metadata_rows: list[dict] = [] - - for original_file, file_entries in by_file.items(): - if not os.path.exists(original_file): - logger.error(f"Original file not found: {original_file}") - continue - - info = sf.info(original_file) - original_name = Path(original_file).stem - file_entries.sort(key=sort_key) - logger.info(f"\nProcessing: {original_name} ({len(file_entries)} entries)") - - for entry in file_entries: - intervals = get_intervals(entry) - for seg_idx, (start_ms, end_ms, dur) in enumerate(intervals): - out_filename = make_filename(original_name, entry, seg_idx) - output_path = os.path.join(output_dir, out_filename) - - try: - audio = _read_segment(original_file, start_ms, end_ms, info.samplerate) - _write_segment(output_path, audio, info.samplerate, output_format) - extracted += 1 - total_dur += dur - - speaker_id = entry.get("speaker_id") - if speaker_id: - speaker_counts[speaker_id] += 1 - - metadata_rows.append( - make_metadata(out_filename, original_file, entry, seg_idx, start_ms, end_ms, dur) - ) - logger.debug(f" {out_filename} ({start_ms}-{end_ms}ms, {dur:.2f}s)") - except Exception as e: # noqa: BLE001 - logger.error(f" Failed to extract {out_filename}: {e}") - - return extracted, total_dur, speaker_counts, metadata_rows - - -# ------------------------------------------------------------------ -# Combo-specific callables -# ------------------------------------------------------------------ - - -def _intervals_from_timestamps(entry: dict) -> list[Interval]: - start_ms = entry.get("original_start_ms", 0) - end_ms = entry.get("original_end_ms", 0) - dur = entry.get("duration", (end_ms - start_ms) / 1000) - return [(start_ms, end_ms, dur)] - - -def _intervals_from_diar_segments(entry: dict) -> list[Interval]: - diar_segments = entry.get("diar_segments", []) - if not diar_segments: - speaker_id = entry.get("speaker_id", "unknown") - logger.warning(f" {speaker_id}: no diar_segments, skipping") - return [] - return [ - (int(s * 1000), int(e * 1000), e - s) - for s, e in sorted(diar_segments, key=lambda x: x[0]) - ] - - -def _base_metadata( # noqa: PLR0913 - filename: str, original_file: str, entry: dict, - seg_idx: int, start_ms: int, end_ms: int, dur: float, -) -> dict: - row: dict = { - "filename": filename, - "original_file": original_file, - "segment_index": seg_idx, - "start_sec": round(start_ms / 1000, 3), - "end_sec": round(end_ms / 1000, 3), - "duration": round(dur, 3), - } - speaker_id = entry.get("speaker_id") - if speaker_id is not None: - row["speaker_id"] = speaker_id - num_speakers = entry.get("num_speakers") - if num_speakers is not None: - row["num_speakers"] = num_speakers - row.update(_extract_scores(entry)) - return row - - -# ------------------------------------------------------------------ -# Combos 1 & 2: extract segments by timestamps -# ------------------------------------------------------------------ - - -def extract_segments_by_timestamps( - entries: list, output_dir: str, output_format: str, -) -> tuple[int, float, dict[str, int], list[dict]]: - """Extract segments by original_start_ms / original_end_ms, sorted by start time.""" - counter: dict[str, int] = defaultdict(int) - - def _make_filename(name: str, _entry: dict, _seg_idx: int) -> str: - idx = counter[name] - counter[name] += 1 - return 
f"{name}_segment_{idx:03d}.{output_format}" - - return _extract_file_segments( - entries, output_dir, output_format, - sort_key=lambda x: x.get("original_start_ms", 0), - get_intervals=_intervals_from_timestamps, - make_filename=_make_filename, - make_metadata=_base_metadata, - ) - - -# ------------------------------------------------------------------ -# Combo 3: speaker only -- extract each diar_segment per speaker -# ------------------------------------------------------------------ - - -def extract_speaker_diar_segments( - entries: list, output_dir: str, output_format: str, -) -> tuple[int, float, dict[str, int], list[dict]]: - """Extract individual speaking intervals from diar_segments per speaker.""" - - def _make_filename(name: str, entry: dict, seg_idx: int) -> str: - _, speaker_num = _get_speaker_label(entry) - return f"{name}_speaker_{speaker_num}_segment_{seg_idx:03d}.{output_format}" - - return _extract_file_segments( - entries, output_dir, output_format, - sort_key=lambda x: x.get("speaker_id", ""), - get_intervals=_intervals_from_diar_segments, - make_filename=_make_filename, - make_metadata=_base_metadata, - ) - - -# ------------------------------------------------------------------ -# Combo 4: VAD + speaker -- extract each speaker-segment by timestamps -# ------------------------------------------------------------------ - - -def extract_speaker_segments_by_timestamps( - entries: list, output_dir: str, output_format: str, -) -> tuple[int, float, dict[str, int], list[dict]]: - """Extract speaker-segments using original_start_ms / original_end_ms.""" - per_speaker_count: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) - - def _make_filename(name: str, entry: dict, _seg_idx: int) -> str: - speaker_id, speaker_num = _get_speaker_label(entry) - idx = per_speaker_count[name][speaker_id] - per_speaker_count[name][speaker_id] += 1 - return f"{name}_speaker_{speaker_num}_segment_{idx:03d}.{output_format}" - - return _extract_file_segments( - entries, output_dir, output_format, - sort_key=lambda x: (x.get("speaker_id", ""), x.get("original_start_ms", 0)), - get_intervals=_intervals_from_timestamps, - make_filename=_make_filename, - make_metadata=_base_metadata, - ) - - -# ------------------------------------------------------------------ -# Main -# ------------------------------------------------------------------ - - -def _write_metadata_csv(output_dir: str, metadata_rows: list[dict]) -> str: - """Write metadata.csv from collected metadata rows.""" - if not metadata_rows: - return "" - - all_keys: list[str] = [] - seen: set[str] = set() - for row in metadata_rows: - for k in row: - if k not in seen: - all_keys.append(k) - seen.add(k) - - csv_path = os.path.join(output_dir, "metadata.csv") - with open(csv_path, "w", newline="") as f: - writer = csv.DictWriter(f, fieldnames=all_keys) - writer.writeheader() - writer.writerows(metadata_rows) - - return csv_path - - -def extract_segments(input_path: str, output_dir: str, output_format: str = DEFAULT_OUTPUT_FORMAT) -> None: - """Extract segments from original audio files based on manifest.""" - os.makedirs(output_dir, exist_ok=True) - - logger.info(f"Loading manifest: {input_path}") - entries = load_manifests(input_path, output_dir) - logger.info(f"Found {len(entries)} entries total") - - if not entries: - logger.error("No entries found in manifest") - return - - combo = detect_combo(entries) - combo_names = { - 2: "Segments by timestamps", - 3: "Speaker diarization segments", - 4: "Speaker-segments by timestamps", - } - 
logger.info(f"Detected: {combo_names[combo]}") - - extractors = { - 2: extract_segments_by_timestamps, - 3: extract_speaker_diar_segments, - 4: extract_speaker_segments_by_timestamps, - } - total_extracted, total_dur, speaker_counts, metadata_rows = extractors[combo](entries, output_dir, output_format) - - csv_path = _write_metadata_csv(output_dir, metadata_rows) - - summary = { - "manifest_path": input_path, - "output_dir": output_dir, - "total_segments": total_extracted, - "total_duration_sec": round(total_dur, 2), - "output_format": output_format, - } - if speaker_counts: - summary["segments_by_speaker"] = dict(speaker_counts) - - summary_path = os.path.join(output_dir, "extraction_summary.json") - with open(summary_path, "w") as f: - json.dump(summary, f, indent=2) - - logger.info(f"\n{'=' * 60}") - logger.info("EXTRACTION COMPLETE") - logger.info(f"{'=' * 60}") - logger.info(f" Combo: {combo_names[combo]}") - logger.info(f" Total segments: {total_extracted}") - logger.info(f" Total duration: {total_dur:.2f}s ({total_dur / 60:.1f} min)") - logger.info(f" Output: {output_dir}") - logger.info(f" Format: {output_format}") - if speaker_counts: - logger.info(" Segments by speaker:") - for speaker, count in sorted(speaker_counts.items()): - logger.info(f" {speaker}: {count} segments") - if csv_path: - logger.info(f" Metadata CSV: {csv_path}") - logger.info(f" Summary: {summary_path}") - - def main() -> int: parser = argparse.ArgumentParser(description="Extract audio segments from original files based on manifest") parser.add_argument( @@ -503,7 +69,8 @@ def main() -> int: return 1 logger.info(f"Output format: {args.output_format}") - extract_segments(input_path=args.manifest, output_dir=args.output_dir, output_format=args.output_format) + stage = SegmentExtractionStage(output_dir=args.output_dir, output_format=args.output_format) + stage.extract_from_manifest(input_path=args.manifest) return 0 From efeb66d24e80485ff2abd92af26698c131ed32bf Mon Sep 17 00:00:00 2001 From: shbhawsar Date: Mon, 20 Apr 2026 08:17:47 -0700 Subject: [PATCH 09/11] fix: accumulate metadata rows across batches in SegmentExtractionStage --- nemo_curator/stages/audio/io/extract_segments.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo_curator/stages/audio/io/extract_segments.py b/nemo_curator/stages/audio/io/extract_segments.py index cdf16e0171..c6216d980b 100644 --- a/nemo_curator/stages/audio/io/extract_segments.py +++ b/nemo_curator/stages/audio/io/extract_segments.py @@ -307,6 +307,7 @@ def __post_init__(self) -> None: if self.output_format not in SOUNDFILE_FORMATS: msg = f"output_format must be one of {list(SOUNDFILE_FORMATS)}, got {self.output_format!r}" raise ValueError(msg) + self._all_metadata_rows: list[dict] = [] def inputs(self) -> tuple[list[str], list[str]]: return [], ["original_file"] @@ -334,7 +335,8 @@ def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]: } extracted, total_dur, speaker_counts, metadata_rows = extractors[combo](entries) - _write_metadata_csv(self.output_dir, metadata_rows) + self._all_metadata_rows.extend(metadata_rows) + _write_metadata_csv(self.output_dir, self._all_metadata_rows) logger.info( f"[{self.name}] Extracted {extracted} segments " From 30d91b5ebef9196c3b7a1fdfdad110dd0a7323da Mon Sep 17 00:00:00 2001 From: shbhawsar Date: Mon, 20 Apr 2026 10:19:53 -0700 Subject: [PATCH 10/11] fix: persistent counters for all combos and single-worker constraint --- .../stages/audio/io/extract_segments.py | 26 ++++++++++++------- 1 file changed, 17 
insertions(+), 9 deletions(-) diff --git a/nemo_curator/stages/audio/io/extract_segments.py b/nemo_curator/stages/audio/io/extract_segments.py index c6216d980b..fb8383ac85 100644 --- a/nemo_curator/stages/audio/io/extract_segments.py +++ b/nemo_curator/stages/audio/io/extract_segments.py @@ -308,6 +308,8 @@ def __post_init__(self) -> None: msg = f"output_format must be one of {list(SOUNDFILE_FORMATS)}, got {self.output_format!r}" raise ValueError(msg) self._all_metadata_rows: list[dict] = [] + self._segment_counter: dict[str, int] = defaultdict(int) + self._speaker_segment_counter: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) def inputs(self) -> tuple[list[str], list[str]]: return [], ["original_file"] @@ -315,6 +317,12 @@ def inputs(self) -> tuple[list[str], list[str]]: def outputs(self) -> tuple[list[str], list[str]]: return [], ["extracted_path"] + def num_workers(self) -> int | None: + return 1 + + def xenna_stage_spec(self) -> dict[str, Any]: + return {"num_workers": 1} + def process(self, task: AudioTask) -> AudioTask: msg = "SegmentExtractionStage only supports process_batch" raise NotImplementedError(msg) @@ -356,11 +364,10 @@ def _extract_by_timestamps( self, entries: list[dict], ) -> tuple[int, float, dict[str, int], list[dict]]: """Combo 2: extract by original_start_ms / original_end_ms.""" - counter: dict[str, int] = defaultdict(int) def _make_filename(name: str, _entry: dict, _seg_idx: int) -> str: - idx = counter[name] - counter[name] += 1 + idx = self._segment_counter[name] + self._segment_counter[name] += 1 return f"{name}_segment_{idx:03d}.{self.output_format}" return self._extract_file_segments( @@ -375,9 +382,11 @@ def _extract_speaker_diar( ) -> tuple[int, float, dict[str, int], list[dict]]: """Combo 3: extract each diar_segment per speaker.""" - def _make_filename(name: str, entry: dict, seg_idx: int) -> str: - _, speaker_num = _get_speaker_label(entry) - return f"{name}_speaker_{speaker_num}_segment_{seg_idx:03d}.{self.output_format}" + def _make_filename(name: str, entry: dict, _seg_idx: int) -> str: + speaker_id, speaker_num = _get_speaker_label(entry) + idx = self._speaker_segment_counter[name][speaker_id] + self._speaker_segment_counter[name][speaker_id] += 1 + return f"{name}_speaker_{speaker_num}_segment_{idx:03d}.{self.output_format}" return self._extract_file_segments( entries, @@ -390,12 +399,11 @@ def _extract_speaker_timestamps( self, entries: list[dict], ) -> tuple[int, float, dict[str, int], list[dict]]: """Combo 4: extract speaker-segments by timestamps.""" - per_speaker_count: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) def _make_filename(name: str, entry: dict, _seg_idx: int) -> str: speaker_id, speaker_num = _get_speaker_label(entry) - idx = per_speaker_count[name][speaker_id] - per_speaker_count[name][speaker_id] += 1 + idx = self._speaker_segment_counter[name][speaker_id] + self._speaker_segment_counter[name][speaker_id] += 1 return f"{name}_speaker_{speaker_num}_segment_{idx:03d}.{self.output_format}" return self._extract_file_segments( From 12e3fdcca0d06ae2ff19aab8a244b4f424368b57 Mon Sep 17 00:00:00 2001 From: shbhawsar Date: Mon, 20 Apr 2026 11:06:17 -0700 Subject: [PATCH 11/11] fix: NameError in run.py, lazy executor factory, reduce GPU allocations to fit single GPU Signed-off-by: shbhawsar --- .../audio_data_filter/default_config.yaml | 8 ++--- tutorials/audio/readspeech/pipeline.py | 34 ++++++++++++++----- tutorials/audio/readspeech/run.py | 29 ++++++++++++---- 3 files changed, 51 insertions(+), 20 
deletions(-) diff --git a/nemo_curator/stages/audio/advanced_pipelines/audio_data_filter/default_config.yaml b/nemo_curator/stages/audio/advanced_pipelines/audio_data_filter/default_config.yaml index dbd173c859..62823b9871 100644 --- a/nemo_curator/stages/audio/advanced_pipelines/audio_data_filter/default_config.yaml +++ b/nemo_curator/stages/audio/advanced_pipelines/audio_data_filter/default_config.yaml @@ -19,7 +19,7 @@ vad: min_interval_ms: 500 speech_pad_ms: 300 cpus: 1.0 - gpus: 0.3 + gpus: 0.1 band_filter: enable: true @@ -31,7 +31,7 @@ utmos: enable: true mos_threshold: 3.4 cpus: 1.0 - gpus: 0.5 + gpus: 0.2 sigmos: enable: true @@ -43,7 +43,7 @@ sigmos: loud_threshold: null reverb_threshold: null cpus: 1.0 - gpus: 0.5 + gpus: 0.2 concatenation: silence_duration_sec: 0.5 @@ -56,7 +56,7 @@ speaker_separation: gap_threshold: 0.1 buffer_time: 0.5 cpus: 1.0 - gpus: 1.0 + gpus: 0.4 timestamp_mapper: passthrough_keys: null diff --git a/tutorials/audio/readspeech/pipeline.py b/tutorials/audio/readspeech/pipeline.py index 9117ff20c1..4db5511e1b 100644 --- a/tutorials/audio/readspeech/pipeline.py +++ b/tutorials/audio/readspeech/pipeline.py @@ -51,20 +51,33 @@ """ import argparse +import importlib import os import shutil import sys from loguru import logger -from nemo_curator.backends.ray_data import RayDataExecutor -from nemo_curator.backends.xenna import XennaExecutor from nemo_curator.pipeline import Pipeline from nemo_curator.stages.audio import AudioDataFilterStage from nemo_curator.stages.audio.datasets.readspeech import CreateInitialManifestReadSpeechStage from nemo_curator.stages.audio.io.convert import AudioToDocumentStage from nemo_curator.stages.text.io.writer import JsonlWriter +_EXECUTOR_FACTORIES = { + "xenna": "nemo_curator.backends.xenna:XennaExecutor", + "ray_data": "nemo_curator.backends.ray_data:RayDataExecutor", +} + + +def _create_executor(backend: str, **kwargs) -> object: + if backend not in _EXECUTOR_FACTORIES: + msg = f"Unknown backend '{backend}'. 
Choose from: {list(_EXECUTOR_FACTORIES)}" + raise ValueError(msg) + module_path, class_name = _EXECUTOR_FACTORIES[backend].rsplit(":", 1) + mod = importlib.import_module(module_path) + return getattr(mod, class_name)(**kwargs) + def create_readspeech_pipeline(args: argparse.Namespace) -> Pipeline: """ @@ -201,7 +214,7 @@ def _build_parser() -> argparse.ArgumentParser: type=str, choices=["streaming", "batch"], default="streaming", - help="Xenna execution mode: 'streaming' (default) or 'batch'", + help="Xenna execution mode: 'streaming' (concurrent stages, default) or 'batch' (sequential stages)", ) parser.add_argument("--verbose", action="store_true", help="Verbose logging") parser.add_argument("--enable-vad", action="store_true", help="Enable VAD segmentation") @@ -257,6 +270,9 @@ def _log_config(args: argparse.Namespace) -> None: enabled.append("SpeakerSep") logger.info(f"Enabled Filters: {enabled or ['none']}") + logger.info(f"Backend: {args.backend}") + if args.backend == "xenna": + logger.info(f"Execution Mode: {args.execution_mode}") logger.info("=" * 70) @@ -278,14 +294,14 @@ def main() -> None: pipeline = create_readspeech_pipeline(args) logger.info(pipeline.describe()) - logger.info("Starting pipeline execution...") + executor_kwargs = {} + if args.backend == "xenna": + executor_kwargs["config"] = {"execution_mode": args.execution_mode} + executor = _create_executor(args.backend, **executor_kwargs) + + logger.info(f"Starting pipeline execution (backend: {args.backend})...") try: - executor = ( - RayDataExecutor() - if args.backend == "ray_data" - else XennaExecutor(config={"execution_mode": args.execution_mode}) - ) pipeline.run(executor) logger.info(f"Results written to {args.output_dir}/*.jsonl") diff --git a/tutorials/audio/readspeech/run.py b/tutorials/audio/readspeech/run.py index 2e380b7bac..8a7fbd2950 100644 --- a/tutorials/audio/readspeech/run.py +++ b/tutorials/audio/readspeech/run.py @@ -29,16 +29,29 @@ enable_utmos=true """ +import importlib import os import hydra from loguru import logger from omegaconf import DictConfig, OmegaConf -from nemo_curator.backends.ray_data import RayDataExecutor -from nemo_curator.backends.xenna import XennaExecutor from nemo_curator.pipeline import Pipeline +_EXECUTOR_FACTORIES = { + "xenna": "nemo_curator.backends.xenna:XennaExecutor", + "ray_data": "nemo_curator.backends.ray_data:RayDataExecutor", +} + + +def _create_executor(backend: str, **kwargs) -> object: + if backend not in _EXECUTOR_FACTORIES: + msg = f"Unknown backend '{backend}'. 
Choose from: {list(_EXECUTOR_FACTORIES)}" + raise ValueError(msg) + module_path, class_name = _EXECUTOR_FACTORIES[backend].rsplit(":", 1) + mod = importlib.import_module(module_path) + return getattr(mod, class_name)(**kwargs) + def create_pipeline_from_yaml(cfg: DictConfig) -> Pipeline: """Create pipeline from Hydra config.""" @@ -70,12 +83,14 @@ def main(cfg: DictConfig) -> None: os.makedirs(output_dir, exist_ok=True) backend = cfg.get("backend", "xenna") - if backend == "ray_data": - executor = RayDataExecutor() - else: + executor_kwargs = {} + if backend == "xenna": execution_mode = cfg.get("execution_mode", "streaming") - executor = XennaExecutor(config={"execution_mode": execution_mode}) - logger.info(f"Starting pipeline execution (mode: {execution_mode})...") + executor_kwargs["config"] = {"execution_mode": execution_mode} + logger.info(f"Starting pipeline execution (backend: {backend}, mode: {execution_mode})...") + else: + logger.info(f"Starting pipeline execution (backend: {backend})...") + executor = _create_executor(backend, **executor_kwargs) pipeline.run(executor) logger.info("\n" + "=" * 60)