Skip to content

Commit e47e0ce

Browse files
authored
Merge branch 'main' into bhamehta/fix-telemetry-deadlock
2 parents 74f0b3a + 93b0ed4 commit e47e0ce

12 files changed

Lines changed: 1563 additions & 5 deletions

File tree

docs/source/how-to/configure-workflows/metrics-configuration.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,28 @@ If you have multiple metrics to evaluate, you can configure them in the followin
128128
```{Note}
129129
If you have more than one metric, you need to specify `priority: {RANK}`, which Olive will use to determine the best model.
130130
```
131+
132+
## Speech Evaluation Metrics (WER and RTFx)
133+
134+
Olive supports Word Error Rate (WER) and Real-Time Factor (RTFx) as built-in accuracy sub-types for evaluating speech/ASR models.
135+
136+
### Using WER with the accuracy metric type
137+
138+
WER can be used as an accuracy sub-type when your data pipeline returns text predictions and references:
139+
140+
```json
141+
{
142+
"name": "speech_accuracy",
143+
"type": "accuracy",
144+
"data_config": "speech_data_config",
145+
"sub_types": [
146+
{"name": "wer", "priority": 1, "higher_is_better": false},
147+
{"name": "rtfx", "priority": 2, "higher_is_better": true}
148+
]
149+
}
150+
```
151+
152+
```{Note}
153+
- `wer` (Word Error Rate): Measures transcription errors. Lower is better (defaults to `higher_is_better: false`).
154+
- `rtfx` (inverse Real-Time Factor): Audio duration divided by inference time, so a value above 1 means faster than real time. Higher is better (defaults to `higher_is_better: true`).
155+
```

olive/data/component/pre_process_data.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,3 +291,91 @@ def _tokenizer_and_align_labels(examples):
291291

292292
tokenized_datasets = _huggingface_pre_process_helper(dataset, _tokenizer_and_align_labels, max_samples, **kwargs)
293293
return ClassificationDataset(tokenized_datasets, label_col="label", max_samples=max_samples)
294+
295+
296+
@Registry.register_pre_process()
def speech_transcription_pre_process(
    dataset,
    audio_col: str = "audio",
    text_col: str = "text",
    sample_rate: int = 16000,
    max_samples: Optional[int] = None,
    limit: Optional[float] = None,
    seed: int = 42,
    **kwargs,
):
    """Build an ASR evaluation dataset of (audio_array, reference_text) pairs.

    The audio column is cast to the requested sampling rate, the dataset is
    optionally sub-sampled, and the result is wrapped in a small indexable
    dataset object that yields float32 audio arrays with their reference text.

    Args:
        dataset: HuggingFace dataset with audio and text columns.
        audio_col: Name of the audio column. Defaults to "audio".
        text_col: Name of the reference text column. Defaults to "text".
        sample_rate: Target sample rate for audio. Defaults to 16000.
        max_samples: Maximum number of samples (deprecated, use limit). Only
            consulted when ``limit`` is None. Defaults to None.
        limit: Sampling limit following Olive convention:
            If >= 1: use first N samples.
            If 0 < limit < 1: randomly sample that percentage.
            If 0 or None: use all samples.
        seed: Random seed for percentage-based sampling. Defaults to 42.
        **kwargs: Additional arguments (unused here).

    Returns:
        A ``SpeechTranscriptionDataset`` wrapping the (possibly sub-sampled) dataset.

    """
    from datasets import Audio

    dataset = dataset.cast_column(audio_col, Audio(sampling_rate=sample_rate))

    # ``limit`` wins whenever it is given; ``max_samples`` is only the fallback.
    if limit is not None:
        effective_limit = limit
    else:
        effective_limit = max_samples or 0

    if effective_limit:
        from random import Random

        num_rows = len(dataset)
        if 0 < effective_limit < 1:
            # Fractional limit: seeded random subset, kept in original row order.
            wanted = max(1, int(num_rows * effective_limit))
            picker = Random(seed)
            chosen = sorted(picker.sample(range(num_rows), min(wanted, num_rows)))
            dataset = dataset.select(chosen)
        elif effective_limit >= 1:
            # Absolute limit: simply take the leading rows.
            dataset = dataset.select(range(min(int(effective_limit), num_rows)))

    class SpeechTranscriptionDataset:
        """Indexable view yielding (audio_array, reference_text) pairs.

        Note: Use batch_size=1 in dataloader config as audio samples have variable lengths.
        """

        def __init__(self, hf_dataset, audio_column, text_column):
            self.dataset = hf_dataset
            self.audio_column = audio_column
            self.text_column = text_column

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            row = self.dataset[idx]
            import numpy as np

            waveform = np.array(row[self.audio_column]["array"], dtype=np.float32)
            transcript = row[self.text_column]
            return waveform, transcript

        @staticmethod
        def collate_fn(batch):
            """Collate variable-length audio batches. Use with batch_size=1 or pad audio."""
            import numpy as np

            if len(batch) != 1:
                # Anything other than a single sample is returned as plain lists (no padding).
                return ([sample[0] for sample in batch], [sample[1] for sample in batch])

            # Expected path for speech evaluation: add a leading batch axis to the lone sample.
            audio, text = batch[0]
            return (np.expand_dims(audio, 0), [text])

    return SpeechTranscriptionDataset(dataset, audio_col, text_col)

olive/data/container/huggingface_container.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,7 @@ class HuggingfaceContainer(DataContainer):
3838
DataComponentType.PRE_PROCESS_DATA.value: "audio_classification_pre_process",
3939
DataComponentType.POST_PROCESS_DATA.value: "text_classification_post_process",
4040
},
41+
"speech-transcription": {
42+
DataComponentType.PRE_PROCESS_DATA.value: "speech_transcription_pre_process",
43+
},
4144
}

olive/evaluator/accuracy.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class AccuracyBase(AutoConfigClass):
2626
"recall": torchmetrics.Recall,
2727
"auroc": torchmetrics.AUROC,
2828
"perplexity": torchmetrics.text.perplexity.Perplexity,
29+
"wer": torchmetrics.text.WordErrorRate,
2930
}
3031

3132
def __init__(self, config: Optional[Union[ConfigBase, dict[str, Any]]] = None) -> None:
@@ -157,3 +158,62 @@ def measure(self, model_output, target):
157158
perplexity.update(logits, targets)
158159
result = perplexity.compute()
159160
return result.item()
161+
162+
163+
class WordErrorRate(AccuracyBase):
    """Word Error Rate (WER) accuracy sub-type for speech/ASR evaluation.

    ``measure`` reads predicted transcriptions from ``model_output.preds`` and
    reference transcriptions from ``target``; each may be a single string, a
    list, or another iterable of strings.
    """

    name: Optional[str] = "wer"

    @classmethod
    def _default_config(cls) -> dict[str, ConfigParam]:
        # No configurable parameters are exposed for this metric.
        return {}

    def measure(self, model_output, target):
        """Return the word error rate of the predictions against the references as a float."""

        def as_text_list(value):
            # A bare string is one transcription, not an iterable of characters.
            if isinstance(value, str):
                return [value]
            return value if isinstance(value, list) else list(value)

        hypotheses = as_text_list(model_output.preds)
        references = as_text_list(target)

        metric = torchmetrics.text.WordErrorRate(**self.config_dict)
        return metric(hypotheses, references).item()
192+
193+
194+
class RealTimeFactor(AccuracyBase):
    """Real-Time Factor (RTFx) metric for speech/ASR evaluation.

    RTFx = total_audio_duration / total_inference_time.
    A value > 1 means faster than real-time (e.g., RTFx=5 means 5x faster).
    Timing metadata is read from ``model_output.logits``, which must be a dict
    holding both ``"total_audio_duration"`` and ``"total_inference_time"``.
    """

    name: Optional[str] = "rtfx"

    @classmethod
    def _default_config(cls) -> dict[str, ConfigParam]:
        # No configurable parameters are exposed for this metric.
        return {}

    def measure(self, model_output, target):
        """Compute RTFx from the timing metadata carried in ``model_output.logits``.

        Args:
            model_output: Object whose ``logits`` attribute is the timing dict.
            target: Unused; present to match the ``measure`` signature of the other metrics.

        Returns:
            ``total_audio_duration / total_inference_time`` rounded to two decimals,
            or ``float("inf")`` when the recorded inference time is zero.

        Raises:
            ValueError: If the timing metadata is not a dict or lacks either required key.

        """
        timing = model_output.logits
        required_keys = ("total_audio_duration", "total_inference_time")
        # Validate both keys up front so a partially filled dict gives the same
        # descriptive error instead of a bare KeyError further down.
        if not isinstance(timing, dict) or any(key not in timing for key in required_keys):
            raise ValueError(
                "RTFx metric requires timing metadata from text-based inference path. "
                "Ensure the metric is used with speech evaluation (WER + RTFx together)."
            )

        total_audio = timing["total_audio_duration"]
        total_inference = timing["total_inference_time"]
        if total_inference == 0:
            # Avoid ZeroDivisionError; zero measured time is reported as infinitely fast.
            return float("inf")
        return round(total_audio / total_inference, 2)

olive/evaluator/lmeval_ort.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,12 @@ def __init__(
509509
self.max_length = max_length
510510
else:
511511
self.max_length = genai_config["search"]["max_length"]
512-
self._eot_token_id = genai_config["model"]["eos_token_id"]
512+
eot = genai_config["model"]["eos_token_id"]
513+
# eos_token_id can be a list (e.g. [1, 106] for Gemma4) or a scalar.
514+
# Store all EOS IDs for generate_until stop detection,
515+
# and first/scalar for loglikelihood (TemplateLM.eot_token_id expects int).
516+
self._eos_token_ids = list(eot) if isinstance(eot, list) else [eot]
517+
self._eot_token_id = self._eos_token_ids[0]
513518
self.params = og.GeneratorParams(self.model)
514519
self.params.set_search_options(max_length=self.max_length, past_present_share_buffer=False)
515520

@@ -575,3 +580,70 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor
575580

576581
def complete(self):
577582
pass
583+
584+
def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
    """Generate text for each request until EOS, a stop sequence, or the token budget.

    Used by benchmarks like MMLU Pro (CoT variant) that score by generating
    chain-of-thought text and extracting the answer with a regex filter.

    Args:
        requests: Iterable of request objects whose ``args`` is
            ``(context, gen_kwargs)``. Recognised ``gen_kwargs`` keys are
            ``"until"`` (a string or list of stop strings), ``"max_gen_toks"``
            (default 256) and ``"temperature"`` (0/None selects greedy decoding).
        disable_tqdm: When True, suppress the progress bar.

    Returns:
        One generated string per request, in request order. A request whose
        context leaves no room under ``self.max_length`` yields ``""``.

    """
    # How often (in generated tokens) the partial output is decoded to look for stop strings.
    decode_interval = 16
    eos_ids = self._eos_token_ids

    results = []
    for request in tqdm(requests, disable=disable_tqdm, desc="Running generate_until requests"):
        context = request.args[0]
        gen_kwargs = request.args[1]

        # Normalise stop sequences: accept None or a bare string, and drop empty
        # strings, which would match everywhere and truncate the output to "".
        until = gen_kwargs.get("until") or []
        if isinstance(until, str):
            until = [until]
        until = [stop_seq for stop_seq in until if stop_seq]
        max_gen_toks = gen_kwargs.get("max_gen_toks", 256)

        input_ids = self.tok_encode(context)
        max_new_tokens = min(max_gen_toks, self.max_length - len(input_ids))
        if max_new_tokens <= 0:
            # The prompt already fills the context window; nothing can be generated.
            results.append("")
            continue

        params = og.GeneratorParams(self.model)
        params.set_search_options(
            max_length=len(input_ids) + max_new_tokens,
            past_present_share_buffer=False,
            batch_size=1,
        )
        temperature = gen_kwargs.get("temperature", 0.0)
        if not temperature:
            # A temperature of 0 (or an explicit None) means greedy decoding.
            params.set_search_options(do_sample=False)
        else:
            params.set_search_options(do_sample=True, temperature=temperature)

        generator = og.Generator(self.model, params)
        generator.append_tokens([input_ids])

        generated_ids = []
        while not generator.is_done():
            generator.generate_next_token()
            token_id = generator.get_next_tokens()[0]
            if token_id in eos_ids:
                # Stop before recording the EOS token so its textual form cannot
                # end up in the decoded continuation.
                break
            generated_ids.append(token_id)
            # Decoding every token would be wasteful, so stop strings are only
            # checked periodically; the final truncation below removes any overshoot.
            if until and len(generated_ids) % decode_interval == 0:
                partial_text = self.tokenizer.decode(generated_ids)
                if any(stop_seq in partial_text for stop_seq in until):
                    break

        generated_text = self.tokenizer.decode(generated_ids)

        # Cut at each stop sequence in turn (later ones are searched in the
        # already-truncated text), so the result contains none of them.
        for stop_seq in until:
            idx = generated_text.find(stop_seq)
            if idx != -1:
                generated_text = generated_text[:idx]

        results.append(generated_text)
    return results

olive/evaluator/metric.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class AccuracySubType(StrEnumBase):
3838
RECALL = "recall"
3939
AUROC = "auroc"
4040
PERPLEXITY = "perplexity"
41+
WER = "wer"
42+
RTFX = "rtfx"
4143

4244

4345
class LatencySubType(StrEnumBase):
@@ -206,7 +208,13 @@ def validate_sub_types(cls, v, info):
206208
# metric_config
207209
metric_config_cls = None
208210
if info.data["type"] == MetricType.ACCURACY:
209-
item["higher_is_better"] = item.get("higher_is_better", True)
211+
# Error rate metrics (WER) default to higher_is_better=False
212+
_error_rate_metrics = {"wer"}
213+
item_name = item["name"] if isinstance(item["name"], str) else item["name"].value
214+
if item_name in _error_rate_metrics:
215+
item["higher_is_better"] = item.get("higher_is_better", False)
216+
else:
217+
item["higher_is_better"] = item.get("higher_is_better", True)
210218
if info.data["backend"] == "torch_metrics":
211219
metric_config_cls = AccuracyBase.registry[item["name"]].get_config_class()
212220
elif info.data["backend"] == "huggingface_metrics":

0 commit comments

Comments
 (0)