Skip to content

Commit 93b0ed4

Browse files
jiafatomCopilot
andauthored
Add WER and RTFx speech evaluation metrics to Olive evaluator (#2444)
## Summary Add Word Error Rate (WER) and Real-Time Factor (RTFx) as built-in accuracy sub-types in the Olive evaluator framework, enabling speech (ASR) model evaluation. - **WER**: Measures transcription accuracy (lower is better) - **RTFx**: Measures inference speed relative to audio duration (higher is better, e.g., RTFx=5 means 5x faster than real-time) ## Changes - **`olive/evaluator/metric.py`** — Add `WER` and `RTFX` to `AccuracySubType` enum - **`olive/evaluator/accuracy.py`** — Add `WordErrorRate` and `RealTimeFactor` classes - **`olive/evaluator/olive_evaluator.py`** — Add `_inference_text()` path in OnnxEvaluator and PyTorchEvaluator for string-based predictions with timing/audio tracking - **`olive/data/component/pre_process_data.py`** — Add `speech_transcription_pre_process` data component with Olive-style sampling (`--limit`/`--seed`) - **`olive/data/container/huggingface_container.py`** — Add `speech-transcription` task type - **`olive/olive_config.json`** — Add `speech` extra dependencies (`jiwer`, `librosa`, `soundfile`) - **`olive/evaluator/examples/`** — Example config and script for speech evaluation ## Testing Tested with Nemotron speech streaming model on LibriSpeech test.clean: - 64 samples: **WER = 2.87%**, RTFx = 5.44 - 5 samples: WER = 5.00%, RTFx = 5.56 Unit tests added for WordErrorRate and RealTimeFactor (7 tests, all passing). ## Usage ```json { "metrics": [{ "name": "speech_eval", "type": "accuracy", "sub_types": [ {"name": "wer", "higher_is_better": false}, {"name": "rtfx", "higher_is_better": true} ] }] } ``` --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 3b207db commit 93b0ed4

8 files changed

Lines changed: 750 additions & 4 deletions

File tree

docs/source/how-to/configure-workflows/metrics-configuration.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,28 @@ If you have multiple metrics to evaluate, you can configure them in the followin
128128
```{Note}
129129
If you have more than one metric, you need to specify `priority: {RANK}`, which Olive will use to determine the best model.
130130
```
131+
132+
## Speech Evaluation Metrics (WER and RTFx)
133+
134+
Olive supports Word Error Rate (WER) and Real-Time Factor (RTFx) as built-in accuracy sub-types for evaluating speech/ASR models.
135+
136+
### Using WER with the accuracy metric type
137+
138+
WER can be used as an accuracy sub-type when your data pipeline returns text predictions and references:
139+
140+
```json
141+
{
142+
"name": "speech_accuracy",
143+
"type": "accuracy",
144+
"data_config": "speech_data_config",
145+
"sub_types": [
146+
{"name": "wer", "priority": 1, "higher_is_better": false},
147+
{"name": "rtfx", "priority": 2, "higher_is_better": true}
148+
]
149+
}
150+
```
151+
152+
```{Note}
153+
- `wer` (Word Error Rate): Measures transcription errors. Lower is better (defaults to `higher_is_better: false`).
154+
- `rtfx` (Real-Time Factor): Ratio of audio duration to inference time. Higher means faster (defaults to `higher_is_better: true`).
155+
```

olive/data/component/pre_process_data.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,91 @@ def _tokenizer_and_align_labels(examples):
291291

292292
tokenized_datasets = _huggingface_pre_process_helper(dataset, _tokenizer_and_align_labels, max_samples, **kwargs)
293293
return ClassificationDataset(tokenized_datasets, label_col="label", max_samples=max_samples)
294+
295+
296+
@Registry.register_pre_process()
297+
def speech_transcription_pre_process(
298+
dataset,
299+
audio_col: str = "audio",
300+
text_col: str = "text",
301+
sample_rate: int = 16000,
302+
max_samples: Optional[int] = None,
303+
limit: Optional[float] = None,
304+
seed: int = 42,
305+
**kwargs,
306+
):
307+
"""Pre-process data for speech transcription (ASR) evaluation.
308+
309+
Loads audio arrays and reference transcription text from a HuggingFace dataset.
310+
Returns a dataset of (audio_array, reference_text) pairs suitable for WER evaluation.
311+
312+
Args:
313+
dataset: HuggingFace dataset with audio and text columns.
314+
audio_col: Name of the audio column. Defaults to "audio".
315+
text_col: Name of the reference text column. Defaults to "text".
316+
sample_rate: Target sample rate for audio. Defaults to 16000.
317+
max_samples: Maximum number of samples (deprecated, use limit). Defaults to None.
318+
limit: Sampling limit following Olive convention:
319+
If >= 1: use first N samples.
320+
If 0 < limit < 1: randomly sample that percentage.
321+
If 0 or None: use all samples.
322+
seed: Random seed for percentage-based sampling. Defaults to 42.
323+
**kwargs: Additional arguments.
324+
325+
"""
326+
from datasets import Audio
327+
328+
dataset = dataset.cast_column(audio_col, Audio(sampling_rate=sample_rate))
329+
330+
# Apply sampling: prefer limit over max_samples
331+
effective_limit = limit if limit is not None else (max_samples if max_samples else 0)
332+
if effective_limit and effective_limit != 0:
333+
from random import Random
334+
335+
total = len(dataset)
336+
if 0 < effective_limit < 1:
337+
n = max(1, int(total * effective_limit))
338+
rng = Random(seed)
339+
indices = sorted(rng.sample(range(total), min(n, total)))
340+
dataset = dataset.select(indices)
341+
elif effective_limit >= 1:
342+
n = min(int(effective_limit), total)
343+
dataset = dataset.select(range(n))
344+
345+
class SpeechTranscriptionDataset:
346+
"""Dataset that returns (audio_array, reference_text) pairs.
347+
348+
Note: Use batch_size=1 in dataloader config as audio samples have variable lengths.
349+
"""
350+
351+
def __init__(self, hf_dataset, audio_column, text_column):
352+
self.dataset = hf_dataset
353+
self.audio_column = audio_column
354+
self.text_column = text_column
355+
356+
def __len__(self):
357+
return len(self.dataset)
358+
359+
def __getitem__(self, idx):
360+
item = self.dataset[idx]
361+
import numpy as np
362+
363+
audio_array = np.array(item[self.audio_column]["array"], dtype=np.float32)
364+
reference_text = item[self.text_column]
365+
return audio_array, reference_text
366+
367+
@staticmethod
368+
def collate_fn(batch):
369+
"""Collate variable-length audio batches. Use with batch_size=1 or pad audio."""
370+
import numpy as np
371+
372+
# batch_size=1 is expected for speech evaluation (variable-length audio)
373+
if len(batch) == 1:
374+
audio, text = batch[0]
375+
return (np.expand_dims(audio, 0), [text])
376+
# For batch_size > 1, return as lists (no padding)
377+
audios = [item[0] for item in batch]
378+
texts = [item[1] for item in batch]
379+
return (audios, texts)
380+
381+
return SpeechTranscriptionDataset(dataset, audio_col, text_col)

olive/data/container/huggingface_container.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,7 @@ class HuggingfaceContainer(DataContainer):
3838
DataComponentType.PRE_PROCESS_DATA.value: "audio_classification_pre_process",
3939
DataComponentType.POST_PROCESS_DATA.value: "text_classification_post_process",
4040
},
41+
"speech-transcription": {
42+
DataComponentType.PRE_PROCESS_DATA.value: "speech_transcription_pre_process",
43+
},
4144
}

olive/evaluator/accuracy.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class AccuracyBase(AutoConfigClass):
2626
"recall": torchmetrics.Recall,
2727
"auroc": torchmetrics.AUROC,
2828
"perplexity": torchmetrics.text.perplexity.Perplexity,
29+
"wer": torchmetrics.text.WordErrorRate,
2930
}
3031

3132
def __init__(self, config: Optional[Union[ConfigBase, dict[str, Any]]] = None) -> None:
@@ -157,3 +158,62 @@ def measure(self, model_output, target):
157158
perplexity.update(logits, targets)
158159
result = perplexity.compute()
159160
return result.item()
161+
162+
163+
class WordErrorRate(AccuracyBase):
164+
"""Word Error Rate metric for speech/ASR evaluation.
165+
166+
Expects model_output.preds to be a list of predicted transcription strings
167+
and target to be a list of reference transcription strings.
168+
"""
169+
170+
name: Optional[str] = "wer"
171+
172+
@classmethod
173+
def _default_config(cls) -> dict[str, ConfigParam]:
174+
return {}
175+
176+
def measure(self, model_output, target):
177+
preds = model_output.preds
178+
refs = target
179+
# Ensure inputs are lists of strings
180+
if isinstance(preds, str):
181+
preds = [preds]
182+
elif not isinstance(preds, list):
183+
preds = list(preds)
184+
if isinstance(refs, str):
185+
refs = [refs]
186+
elif not isinstance(refs, list):
187+
refs = list(refs)
188+
189+
wer = torchmetrics.text.WordErrorRate(**self.config_dict)
190+
result = wer(preds, refs)
191+
return result.item()
192+
193+
194+
class RealTimeFactor(AccuracyBase):
195+
"""Real-Time Factor (RTFx) metric for speech/ASR evaluation.
196+
197+
RTFx = total_audio_duration / total_inference_time.
198+
A value > 1 means faster than real-time (e.g., RTFx=5 means 5x faster).
199+
Timing metadata is provided via model_output.logits dict.
200+
"""
201+
202+
name: Optional[str] = "rtfx"
203+
204+
@classmethod
205+
def _default_config(cls) -> dict[str, ConfigParam]:
206+
return {}
207+
208+
def measure(self, model_output, target):
209+
timing = model_output.logits
210+
if not isinstance(timing, dict) or "total_audio_duration" not in timing:
211+
raise ValueError(
212+
"RTFx metric requires timing metadata from text-based inference path. "
213+
"Ensure the metric is used with speech evaluation (WER + RTFx together)."
214+
)
215+
total_audio = timing["total_audio_duration"]
216+
total_inference = timing["total_inference_time"]
217+
if total_inference == 0:
218+
return float("inf")
219+
return round(total_audio / total_inference, 2)

olive/evaluator/metric.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class AccuracySubType(StrEnumBase):
3838
RECALL = "recall"
3939
AUROC = "auroc"
4040
PERPLEXITY = "perplexity"
41+
WER = "wer"
42+
RTFX = "rtfx"
4143

4244

4345
class LatencySubType(StrEnumBase):
@@ -206,7 +208,13 @@ def validate_sub_types(cls, v, info):
206208
# metric_config
207209
metric_config_cls = None
208210
if info.data["type"] == MetricType.ACCURACY:
209-
item["higher_is_better"] = item.get("higher_is_better", True)
211+
# Error rate metrics (WER) default to higher_is_better=False
212+
_error_rate_metrics = {"wer"}
213+
item_name = item["name"] if isinstance(item["name"], str) else item["name"].value
214+
if item_name in _error_rate_metrics:
215+
item["higher_is_better"] = item.get("higher_is_better", False)
216+
else:
217+
item["higher_is_better"] = item.get("higher_is_better", True)
210218
if info.data["backend"] == "torch_metrics":
211219
metric_config_cls = AccuracyBase.registry[item["name"]].get_config_class()
212220
elif info.data["backend"] == "huggingface_metrics":

0 commit comments

Comments
 (0)