@@ -291,3 +291,91 @@ def _tokenizer_and_align_labels(examples):
291291
292292 tokenized_datasets = _huggingface_pre_process_helper (dataset , _tokenizer_and_align_labels , max_samples , ** kwargs )
293293 return ClassificationDataset (tokenized_datasets , label_col = "label" , max_samples = max_samples )
294+
295+
class SpeechTranscriptionDataset:
    """Map-style dataset yielding ``(audio_array, reference_text)`` pairs.

    Defined at module level (not inside the pre-process function) so that
    instances can be pickled; classes local to a function cannot be, which
    breaks multi-worker data loading under the ``spawn`` start method.

    Note: Use batch_size=1 in dataloader config as audio samples have variable lengths.
    """

    def __init__(self, hf_dataset, audio_column, text_column):
        """Store the wrapped dataset and the names of the columns to read.

        Args:
            hf_dataset: Indexable dataset whose items are mappings by column name.
            audio_column: Column whose value exposes the samples under ``"array"``.
            text_column: Column holding the reference transcription.

        """
        self.dataset = hf_dataset
        self.audio_column = audio_column
        self.text_column = text_column

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(float32 audio array, reference text)`` for sample ``idx``."""
        import numpy as np

        item = self.dataset[idx]
        audio_array = np.array(item[self.audio_column]["array"], dtype=np.float32)
        return audio_array, item[self.text_column]

    @staticmethod
    def collate_fn(batch):
        """Collate variable-length audio batches. Use with batch_size=1 or pad audio.

        Args:
            batch: List of ``(audio_array, reference_text)`` pairs.

        Returns:
            For a single-item batch: ``(audio with a leading batch axis, [text])``.
            Otherwise: ``(list of audio arrays, list of texts)`` with no padding.

        """
        import numpy as np

        # batch_size=1 is expected for speech evaluation (variable-length audio)
        if len(batch) == 1:
            audio, text = batch[0]
            return (np.expand_dims(audio, 0), [text])
        # For batch_size > 1, return as lists (no padding)
        audios = [pair[0] for pair in batch]
        texts = [pair[1] for pair in batch]
        return (audios, texts)


def _speech_sample_indices(total: int, limit: float, seed: int):
    """Return the row indices to keep for a positive sampling ``limit``.

    Args:
        total: Number of rows available.
        limit: If ``0 < limit < 1``, the fraction of rows to draw at random
            (at least one row when any exist). If ``limit >= 1``, the number
            of leading rows to keep, truncated to an int and capped at ``total``.
        seed: Seed for the random generator used by fractional sampling.

    Returns:
        A sorted list of randomly drawn indices (fractional limit), a ``range``
        over the leading rows (count limit), or ``None`` when ``limit`` matches
        neither case and every row should be kept.

    """
    if 0 < limit < 1:
        from random import Random

        # Keep at least one row, but never ask for more rows than exist.
        count = min(max(1, int(total * limit)), total)
        # Sorting keeps the sampled rows in their original dataset order.
        return sorted(Random(seed).sample(range(total), count))
    if limit >= 1:
        return range(min(int(limit), total))
    return None


@Registry.register_pre_process()
def speech_transcription_pre_process(
    dataset,
    audio_col: str = "audio",
    text_col: str = "text",
    sample_rate: int = 16000,
    max_samples: Optional[int] = None,
    limit: Optional[float] = None,
    seed: int = 42,
    **kwargs,
):
    """Pre-process data for speech transcription (ASR) evaluation.

    Casts the audio column to the target sample rate, optionally sub-samples
    the dataset, and wraps it so that each item is an
    (audio_array, reference_text) pair suitable for WER evaluation.

    Args:
        dataset: HuggingFace dataset with audio and text columns.
        audio_col: Name of the audio column. Defaults to "audio".
        text_col: Name of the reference text column. Defaults to "text".
        sample_rate: Target sample rate for audio. Defaults to 16000.
        max_samples: Maximum number of samples (deprecated, use limit). Only
            consulted when ``limit`` is None. Defaults to None.
        limit: Sampling limit following Olive convention:
            If >= 1: use first N samples.
            If 0 < limit < 1: randomly sample that percentage.
            If 0, negative or None: use all samples.
        seed: Random seed for percentage-based sampling. Defaults to 42.
        **kwargs: Additional arguments; accepted and ignored.

    Returns:
        SpeechTranscriptionDataset wrapping the (possibly sub-sampled) dataset.

    """
    from datasets import Audio

    dataset = dataset.cast_column(audio_col, Audio(sampling_rate=sample_rate))

    # Apply sampling: prefer limit over max_samples
    if limit is not None:
        effective_limit = limit
    else:
        effective_limit = max_samples or 0

    # Non-positive limits mean "keep everything"; skip the length query entirely.
    if effective_limit > 0:
        indices = _speech_sample_indices(len(dataset), effective_limit, seed)
        if indices is not None:
            dataset = dataset.select(indices)

    return SpeechTranscriptionDataset(dataset, audio_col, text_col)
0 commit comments