Commit bd72224

Applied lintrunner changes

1 parent bb5b208 commit bd72224

6 files changed: 51 additions & 46 deletions
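
Nearly all of the diff is mechanical lintrunner output: import blocks regrouped and alphabetized isort-style (standard library first, then third-party, separated by a blank line), single-quoted strings rewritten to double quotes, redundant parentheses dropped from if conditions, trailing whitespace stripped from blank lines, and a final newline added where missing. A minimal illustrative before/after (not taken from the diff itself):

    # before
    import os
    import argparse
    import numpy as np
    if(mode == 'fast'):
        run()

    # after: stdlib alphabetized and separated from third-party,
    # double quotes, no parentheses around the condition
    import argparse
    import os

    import numpy as np

    if mode == "fast":
        run()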


openai-whisper-large-v3-turbo/olive/app.py

Lines changed: 12 additions & 12 deletions
@@ -6,16 +6,17 @@
 from __future__ import annotations

 import os
+
 import numpy as np
 import onnxruntime as ort
 import torch
-from transformers import WhisperProcessor
-
 from qai_hub_models.models._shared.hf_whisper.app import HfWhisperApp, chunk_and_resample_audio
 from qai_hub_models.models._shared.hf_whisper.model import (
     CHUNK_LENGTH,
     SAMPLE_RATE,
 )
+from transformers import WhisperProcessor
+

 def infer_audio(app, model_id, audio_file, save_data):
     audio_dict = np.load(audio_file, allow_pickle=True).item()
@@ -25,7 +26,7 @@ def infer_audio(app, model_id, audio_file, save_data):
     audio_name = os.path.splitext(os.path.basename(audio_file))[0] if save_data else None

     processor = WhisperProcessor.from_pretrained(model_id)
-    reference = processor.tokenizer._normalize(audio_dict['text'])
+    reference = processor.tokenizer._normalize(audio_dict["text"])
     print("Reference: ", reference)

     # Perform transcription
@@ -47,17 +48,17 @@ def __init__(
     ):
         super().__init__(None, None, hf_model_id, sample_rate, max_audio_seconds)
         options = ort.SessionOptions()
-
+
         self.encoder = ort.InferenceSession(encoder,
                                             sess_options=options,
                                             providers=[execution_provider],
                                             provider_options=[provider_options])
-
+
         self.decoder = ort.InferenceSession(decoder,
                                             sess_options=options,
                                             providers=[execution_provider],
                                             provider_options=[provider_options])
-
+
     def transcribe_tokens(
         self, audio, sample_rate, audio_name, save_data = False
     ) -> list[int]:
@@ -71,13 +72,13 @@ def transcribe_tokens(
         for chunk_tokens in out_chunked_tokens:
             out_tokens.extend(chunk_tokens)
         return out_tokens
-
+
     def transcribe(
         self, audio, sample_rate, audio_name, save_data = False
     ) -> str:
         tokens = self.transcribe_tokens(audio, sample_rate, audio_name, save_data)
         return self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
-
+
     def _transcribe_single_chunk(self, audio: np.ndarray, audio_name = None, chunk_number = None, save_data = False) -> list[int]:
         # feature
         input_features = self.feature_extractor(
@@ -87,7 +88,7 @@ def _transcribe_single_chunk(self, audio: np.ndarray, audio_name = None, chunk_n
         # encoder
         output_names_encoder = [output.name for output in self.encoder.get_outputs()]
         # kv_cache_cross = self.encoder(input_features)
-        input_features_feed = {'input_features': input_features}
+        input_features_feed = {"input_features": input_features}

         if(save_data):
             input_features_save_path = os.path.join(save_data, audio_name, f"{chunk_number}_input_features.npy")
@@ -170,7 +171,7 @@ def _transcribe_single_chunk(self, audio: np.ndarray, audio_name = None, chunk_n
         # print("decoder_input: ", decoder_input)
         input_names_decoder = [input.name for input in self.decoder.get_inputs()]
         output_names_decoder = [output.name for output in self.decoder.get_outputs()]
-
+
         # decoder_input_feed = dict(zip(input_names_decoder, decoder_input))
         decoder_input_feed = {name: tensor.numpy() if isinstance(tensor, torch.Tensor) else tensor for name, tensor in zip(input_names_decoder, decoder_input)}

@@ -179,7 +180,7 @@ def _transcribe_single_chunk(self, audio: np.ndarray, audio_name = None, chunk_n
             os.makedirs(os.path.dirname(decoder_input_save_path), exist_ok=True)
             np.save(decoder_input_save_path, decoder_input_feed)

-        decoder_output_numpy = self.decoder.run(output_names_decoder, decoder_input_feed)
+        decoder_output_numpy = self.decoder.run(output_names_decoder, decoder_input_feed)
         decoder_output = [torch.from_numpy(arr) for arr in decoder_output_numpy]
         # decoder_output = self.decoder(*decoder_input)
         if isinstance(decoder_output, tuple) and len(decoder_output) == 2:
@@ -206,4 +207,3 @@ def _transcribe_single_chunk(self, audio: np.ndarray, audio_name = None, chunk_n
             position_ids += 1

         return output_ids[0].tolist()
-
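
For context on the session setup reformatted above: the app builds one ONNX Runtime InferenceSession per Whisper component, with QNN provider options passed as a list aligned one-to-one with the providers list. A minimal hedged sketch of the same pattern; the model path, feature shape, and options dict here are illustrative, not taken from this repo:

    import numpy as np
    import onnxruntime as ort

    options = ort.SessionOptions()
    qnn_options = {"backend_path": "QnnHtp.dll"}  # HTP backend, as in demo.py

    encoder = ort.InferenceSession(
        "whisper_encoder.onnx",              # hypothetical path
        sess_options=options,
        providers=["QNNExecutionProvider"],
        provider_options=[qnn_options],      # one dict per provider, in order
    )

    # Feed shape (1, 128, 3000) assumes large-v3's 128-bin log-mel features.
    feed = {"input_features": np.zeros((1, 128, 3000), np.float32)}
    outputs = encoder.run([o.name for o in encoder.get_outputs()], feed)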

openai-whisper-large-v3-turbo/olive/demo.py

Lines changed: 17 additions & 13 deletions
@@ -3,10 +3,12 @@
 # SPDX-License-Identifier: BSD-3-Clause
 # ---------------------------------------------------------------------

-import os
 import argparse
+import os
+
 from app import HfWhisperAppWithSave, infer_audio

+
 def main():
     parser = argparse.ArgumentParser(description="Demo")
     parser.add_argument(
@@ -55,25 +57,27 @@ def main():
     decoder_path = args.decoder

     provider_options = {}
-    if(args.execution_provider == "QNNExectionProvider"):
-        provider_options = {"backend_path": "QnnHtp.dll",
-                            "htp_performance_mode": "sustained_high_performance",
-                            "htp_graph_finalization_optimization_mode": "3",
-                            "offload_graph_io_quantization": "0",
-                            }
-
+    if args.execution_provider == "QNNExectionProvider":
+        provider_options = {
+            "backend_path": "QnnHtp.dll",
+            "htp_performance_mode": "sustained_high_performance",
+            "htp_graph_finalization_optimization_mode": "3",
+            "offload_graph_io_quantization": "0",
+        }
+
     app = HfWhisperAppWithSave(encoder_path, decoder_path, args.model_id, args.execution_provider, provider_options)

-    if os.path.isdir(args.audio_path):
+    if os.path.isdir(args.audio_path):
         for i, item in enumerate(os.listdir(args.audio_path)):
-            if(args.save_data and i == args.num_data):
+            if args.save_data and i == args.num_data:
                 break
-
+
             full_path = os.path.join(args.audio_path, item)
             infer_audio(app, args.model_id, full_path, args.save_data)
-
+
     else:
         infer_audio(app, args.model_id, args.audio_path, args.save_data)

+
 if __name__ == "__main__":
-    main()
+    main()
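
A hedged invocation sketch, assuming the argparse flags mirror the attribute names read from args (the parser definitions are truncated above). Note the script compares the provider against the literal string "QNNExectionProvider" as spelled in the source:

    python demo.py \
        --encoder models/encoder.onnx \
        --decoder models/decoder.onnx \
        --model_id openai/whisper-large-v3-turbo \
        --execution_provider QNNExectionProvider \
        --audio_path data/audio/ \
        --save_data saved/ \
        --num_data 4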

openai-whisper-large-v3-turbo/olive/download_librispeech_asr.py

Lines changed: 3 additions & 2 deletions
@@ -1,9 +1,10 @@
-import os
 import argparse
+import os
+
 import numpy as np
-from itertools import islice
 from datasets import load_dataset

+
 def download_librispeech_asr(save_dir):
     # Create save_dir if it doesn't exist
     save_dir = os.path.join(save_dir, "librispeech_asr_clean_test")
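
For orientation, a hedged sketch of what this script appears to do, based on its imports and the save_dir name: pull the LibriSpeech clean test split and pickle each item into an .npy dict that app.py reloads with np.load(..., allow_pickle=True).item(). The field names follow the Hugging Face librispeech_asr schema; the exact keys the script writes are an assumption.

    import os

    import numpy as np
    from datasets import load_dataset

    def save_samples(save_dir, n=4):
        save_dir = os.path.join(save_dir, "librispeech_asr_clean_test")
        os.makedirs(save_dir, exist_ok=True)
        ds = load_dataset("librispeech_asr", "clean", split="test", streaming=True)
        for i, item in enumerate(ds):
            if i == n:
                break
            sample = {
                "audio": item["audio"]["array"],  # float waveform
                "sampling_rate": item["audio"]["sampling_rate"],
                "text": item["text"],             # reference transcript
            }
            np.save(os.path.join(save_dir, f"sample_{i}.npy"), sample)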

openai-whisper-large-v3-turbo/olive/evaluate_whisper.py

Lines changed: 7 additions & 7 deletions
@@ -3,16 +3,16 @@
 # SPDX-License-Identifier: BSD-3-Clause
 # ---------------------------------------------------------------------

-# from qai_hub_models.models._shared.hf_whisper.demo import hf_whisper_demo # noqa
-# from qai_hub_models.models.whisper_small.model import WhisperSmall # noqa
+# from qai_hub_models.models._shared.hf_whisper.demo import hf_whisper_demo
+# from qai_hub_models.models.whisper_small.model import WhisperSmall

 import argparse

-from evaluate import load
+from app import HfWhisperAppWithSave
 from datasets import load_dataset
+from evaluate import load
 from transformers import WhisperProcessor

-from app import HfWhisperAppWithSave

 def main():
     parser = argparse.ArgumentParser(description="Evaluate Whisper")
@@ -50,7 +50,7 @@ def main():
            "htp_graph_finalization_optimization_mode": "3",
            "offload_graph_io_quantization": "0",
        }
-
+
    processor = WhisperProcessor.from_pretrained(args.model_id)
    app = HfWhisperAppWithSave(encoder_path, decoder_path, args.model_id, args.execution_provider, provider_options)

@@ -68,7 +68,7 @@ def main():
        transcription = app.transcribe(audio, audio_sample_rate, None, None)
        prediction = processor.tokenizer._normalize(transcription)

-        reference = processor.tokenizer._normalize(item['text'])
+        reference = processor.tokenizer._normalize(item["text"])
        references.append(reference)
        predictions.append(prediction)
        print("Reference: ", reference)
@@ -78,4 +78,4 @@ def main():
    print("WER:", 100 * wer.compute(references=references, predictions=predictions))

 if __name__ == "__main__":
-    main()
+    main()
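
The scoring path above is standard: both strings pass through the Whisper tokenizer's normalizer, then the evaluate library's WER metric compares them. A self-contained sketch with illustrative strings:

    from evaluate import load
    from transformers import WhisperProcessor

    processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3-turbo")
    wer = load("wer")

    reference = processor.tokenizer._normalize("Hello, World!")
    prediction = processor.tokenizer._normalize("hello world")

    # wer.compute returns a fraction; the script reports it as a percentage
    print("WER:", 100 * wer.compute(references=[reference], predictions=[prediction]))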

openai-whisper-large-v3-turbo/olive/whisper_decoder_load.py

Lines changed: 6 additions & 6 deletions
@@ -1,15 +1,15 @@
+import glob
 import os
+
 import numpy as np
-import os
-import glob
+from qai_hub_models.utils.input_spec import make_torch_inputs

 from olive.data.registry import Registry
-from qai_hub_models.utils.input_spec import make_torch_inputs


 def model_loader(model_name):
     if(model_name == "openai/whisper-large-v3-turbo"):
-        from qai_hub_models.models.whisper_large_v3_turbo import Model
+        from qai_hub_models.models.whisper_large_v3_turbo import Model
         model = Model.from_pretrained()
         component = model.components["HfWhisperDecoder"]
         return component
@@ -26,11 +26,11 @@ def generate_dummy_inputs(model=None):

 class DecoderBaseDataLoader:
     def __init__(self, data_path):
-        self.data_files = glob.glob(os.path.join(data_path, '**', '*_decoder_input.npy'), recursive=True)
+        self.data_files = glob.glob(os.path.join(data_path, "**", "*_decoder_input.npy"), recursive=True)

     def __len__(self):
         return len(self.data_files)
-
+
     def __getitem__(self, idx):
         return np.load(self.data_files[idx], allow_pickle=True).item()

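
For context, the *_decoder_input.npy files this loader globs are written by app.py with np.save(path, decoder_input_feed), where the feed is a plain dict of input names to arrays; numpy pickles the dict, which is why the loader needs allow_pickle=True and .item(). A minimal round-trip sketch (name and shape illustrative):

    import numpy as np

    feed = {"input_ids": np.zeros((1, 1), np.int64)}  # illustrative entry
    np.save("0_decoder_input.npy", feed)              # dict is pickled into the .npy

    restored = np.load("0_decoder_input.npy", allow_pickle=True).item()
    assert restored.keys() == feed.keys()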

openai-whisper-large-v3-turbo/olive/whisper_encoder_load.py

Lines changed: 6 additions & 6 deletions
@@ -1,10 +1,10 @@
+import glob
 import os
+
 import numpy as np
-import os
-import glob
+from qai_hub_models.utils.input_spec import make_torch_inputs

 from olive.data.registry import Registry
-from qai_hub_models.utils.input_spec import make_torch_inputs


 def model_loader(model_name):
@@ -25,14 +25,14 @@ def generate_dummy_inputs(model=None):

 class EncoderBaseDataLoader:
     def __init__(self, data_path):
-        self.data_files = glob.glob(os.path.join(data_path, '**', '*_input_features.npy'), recursive=True)
+        self.data_files = glob.glob(os.path.join(data_path, "**", "*_input_features.npy"), recursive=True)

     def __len__(self):
         return len(self.data_files)
-
+
     def __getitem__(self, idx):
         return np.load(self.data_files[idx], allow_pickle=True).item()
-
+
 @Registry.register_dataloader()
 def encoder_data_loader(dataset, data_path):
     return EncoderBaseDataLoader(data_path)
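
A hedged usage sketch: the registered loader can also be exercised directly, outside an Olive run, to confirm the saved encoder features load back as feed dicts (module name and data path assumed):

    from whisper_encoder_load import EncoderBaseDataLoader

    loader = EncoderBaseDataLoader("saved")  # dir holding *_input_features.npy
    print(len(loader), "calibration samples")
    feed = loader[0]                         # dict: input name -> np.ndarray
    print(list(feed.keys()))                 # e.g. ["input_features"]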
