66from __future__ import annotations
77
88import os
9+
910import numpy as np
1011import onnxruntime as ort
1112import torch
12- from transformers import WhisperProcessor
13-
1413from qai_hub_models .models ._shared .hf_whisper .app import HfWhisperApp , chunk_and_resample_audio
1514from qai_hub_models .models ._shared .hf_whisper .model import (
1615 CHUNK_LENGTH ,
1716 SAMPLE_RATE ,
1817)
18+ from transformers import WhisperProcessor
19+
1920
2021def infer_audio (app , model_id , audio_file , save_data ):
2122 audio_dict = np .load (audio_file , allow_pickle = True ).item ()
@@ -25,7 +26,7 @@ def infer_audio(app, model_id, audio_file, save_data):
2526 audio_name = os .path .splitext (os .path .basename (audio_file ))[0 ] if save_data else None
2627
2728 processor = WhisperProcessor .from_pretrained (model_id )
28- reference = processor .tokenizer ._normalize (audio_dict [' text' ])
29+ reference = processor .tokenizer ._normalize (audio_dict [" text" ])
2930 print ("Reference: " , reference )
3031
3132 # Perform transcription
@@ -47,17 +48,17 @@ def __init__(
4748 ):
4849 super ().__init__ (None , None , hf_model_id , sample_rate , max_audio_seconds )
4950 options = ort .SessionOptions ()
50-
51+
5152 self .encoder = ort .InferenceSession (encoder ,
5253 sess_options = options ,
5354 providers = [execution_provider ],
5455 provider_options = [provider_options ])
55-
56+
5657 self .decoder = ort .InferenceSession (decoder ,
5758 sess_options = options ,
5859 providers = [execution_provider ],
5960 provider_options = [provider_options ])
60-
61+
6162 def transcribe_tokens (
6263 self , audio , sample_rate , audio_name , save_data = False
6364 ) -> list [int ]:
@@ -71,13 +72,13 @@ def transcribe_tokens(
7172 for chunk_tokens in out_chunked_tokens :
7273 out_tokens .extend (chunk_tokens )
7374 return out_tokens
74-
75+
def transcribe(
    self, audio, sample_rate, audio_name, save_data=False
) -> str:
    """Transcribe audio to plain text.

    Delegates token generation to transcribe_tokens(), then decodes the
    resulting ids with the HF tokenizer, dropping special tokens and any
    surrounding whitespace.
    """
    token_ids = self.transcribe_tokens(audio, sample_rate, audio_name, save_data)
    decoded = self.tokenizer.decode(token_ids, skip_special_tokens=True)
    return decoded.strip()
80-
81+
8182 def _transcribe_single_chunk (self , audio : np .ndarray , audio_name = None , chunk_number = None , save_data = False ) -> list [int ]:
8283 # feature
8384 input_features = self .feature_extractor (
@@ -87,7 +88,7 @@ def _transcribe_single_chunk(self, audio: np.ndarray, audio_name = None, chunk_n
8788 # encoder
8889 output_names_encoder = [output .name for output in self .encoder .get_outputs ()]
8990 # kv_cache_cross = self.encoder(input_features)
90- input_features_feed = {' input_features' : input_features }
91+ input_features_feed = {" input_features" : input_features }
9192
9293 if (save_data ):
9394 input_features_save_path = os .path .join (save_data , audio_name , f"{ chunk_number } _input_features.npy" )
@@ -170,7 +171,7 @@ def _transcribe_single_chunk(self, audio: np.ndarray, audio_name = None, chunk_n
170171 # print("decoder_input: ", decoder_input)
171172 input_names_decoder = [input .name for input in self .decoder .get_inputs ()]
172173 output_names_decoder = [output .name for output in self .decoder .get_outputs ()]
173-
174+
174175 # decoder_input_feed = dict(zip(input_names_decoder, decoder_input))
175176 decoder_input_feed = {name : tensor .numpy () if isinstance (tensor , torch .Tensor ) else tensor for name , tensor in zip (input_names_decoder , decoder_input )}
176177
@@ -179,7 +180,7 @@ def _transcribe_single_chunk(self, audio: np.ndarray, audio_name = None, chunk_n
179180 os .makedirs (os .path .dirname (decoder_input_save_path ), exist_ok = True )
180181 np .save (decoder_input_save_path , decoder_input_feed )
181182
182- decoder_output_numpy = self .decoder .run (output_names_decoder , decoder_input_feed )
183+ decoder_output_numpy = self .decoder .run (output_names_decoder , decoder_input_feed )
183184 decoder_output = [torch .from_numpy (arr ) for arr in decoder_output_numpy ]
184185 # decoder_output = self.decoder(*decoder_input)
185186 if isinstance (decoder_output , tuple ) and len (decoder_output ) == 2 :
@@ -206,4 +207,3 @@ def _transcribe_single_chunk(self, audio: np.ndarray, audio_name = None, chunk_n
206207 position_ids += 1
207208
208209 return output_ids [0 ].tolist ()
209-
0 commit comments