@@ -44,18 +44,12 @@ import androidx.compose.ui.unit.dp
4444import androidx.core.content.ContextCompat
4545import androidx.lifecycle.ViewModelProvider
4646import com.example.whisperapp.ui.theme.WhisperAppTheme
47- import org.pytorch.executorch.EValue
48- import org.pytorch.executorch.Module
49- import org.pytorch.executorch.Tensor
5047import org.pytorch.executorch.extension.asr.AsrCallback
5148import org.pytorch.executorch.extension.asr.AsrModule
5249import java.io.File
53- import java.io.FileInputStream
5450import java.io.FileOutputStream
5551import java.io.IOException
5652import java.io.OutputStream
57- import java.nio.ByteBuffer
58- import java.nio.ByteOrder
5953
6054class MainActivity : ComponentActivity (), AsrCallback {
6155
@@ -97,48 +91,6 @@ class MainActivity : ComponentActivity(), AsrCallback {
9791 SETTINGS
9892 }
9993
100- @Throws(IOException ::class )
101- fun readWavPcmBytes (filePath : String ): ByteArray {
102- val wavHeaderSize = 44 // Standard header size for PCM WAV
103- val file = File (filePath)
104- val fis = FileInputStream (file)
105- try {
106- val totalSize = file.length()
107- assert (totalSize > wavHeaderSize)
108- val pcmSize = (totalSize - wavHeaderSize).toInt()
109- val pcmBytes = ByteArray (pcmSize)
110- // Skip the header
111- val skipped = fis.skip(wavHeaderSize.toLong())
112- if (skipped != wavHeaderSize.toLong()) throw IOException (" Failed to skip WAV header" )
113- // Read PCM data
114- val read = fis.read(pcmBytes)
115- if (read != pcmSize) throw IOException (" Failed to read all PCM data" )
116- return pcmBytes
117- } finally {
118- fis.close()
119- }
120- }
121-
122- private fun convertPcm16ToFloat (audioBytes : ByteArray ): FloatArray {
123- val totalSamples = audioBytes.size / 2 // 2 bytes per 16-bit sample
124- val floatSamples = FloatArray (totalSamples)
125-
126- // Create ByteBuffer with little-endian byte order (standard for WAV)
127- val byteBuffer = ByteBuffer .wrap(audioBytes).order(ByteOrder .LITTLE_ENDIAN )
128-
129- for (i in 0 until totalSamples) {
130- val sample = byteBuffer.short.toInt()
131- // Normalize 16-bit PCM to [-1.0, 1.0]
132- floatSamples[i] = if (sample < 0 ) {
133- sample / 32768.0f
134- } else {
135- sample / 32767.0f
136- }
137- }
138-
139- return floatSamples
140- }
141-
14294 override fun onCreate (savedInstanceState : Bundle ? ) {
14395 super .onCreate(savedInstanceState)
14496
@@ -252,62 +204,21 @@ class MainActivity : ComponentActivity(), AsrCallback {
252204
253205 runOnUiThread {
254206 transcriptionOutput = " "
255- }
256-
257- val audioData: FloatArray
258- val batchSize: Int
259- val featureDim: Int
260- val timeSteps: Int
261-
262- if (settings.hasPreprocessor()) {
263- // Use preprocessor to convert WAV to mel-spectrogram
264- Log .v(TAG , " Using preprocessor: ${settings.preprocessorPath} " )
265- runOnUiThread {
266- statusText = " Processing audio with mel-spectrogram..."
267- }
268-
269- val pcmBytes = readWavPcmBytes(wavFilePath)
270- val inputFloatArray = convertPcm16ToFloat(pcmBytes)
271-
272- val tensor1 = Tensor .fromBlob(
273- inputFloatArray,
274- longArrayOf(inputFloatArray.size.toLong())
275- )
276- val module = Module .load(settings.preprocessorPath)
277- val eValue1 = EValue .from(tensor1)
278- audioData = module.forward(eValue1)[0 ].toTensor().dataAsFloatArray
279-
280- // result shape is [batchSize, timeSteps, featureDim]
281- batchSize = 1
282- featureDim = 128 // Whisper uses 128 mel bins
283- timeSteps = audioData.size / (batchSize * featureDim)
284- } else {
285- // No preprocessor: use raw WAV audio directly
286- Log .v(TAG , " No preprocessor, using raw WAV audio" )
287- runOnUiThread {
288- statusText = " Processing raw audio..."
289- }
290-
291- val pcmBytes = readWavPcmBytes(wavFilePath)
292- audioData = convertPcm16ToFloat(pcmBytes)
293-
294- // For raw audio: batchSize=1, timeSteps=numSamples, featureDim=1
295- batchSize = 1
296- featureDim = 1 // Raw audio has 1 feature dimension
297- timeSteps = audioData.size
207+ statusText = " Loading model..."
298208 }
299209
300210 val whisperModule = AsrModule (
301- settings.modelPath,
302- settings.tokenizerPath,
303- settings.dataPath
211+ modelPath = settings.modelPath,
212+ tokenizerPath = settings.tokenizerPath,
213+ dataPath = settings.dataPath.ifBlank { null },
214+ preprocessorPath = settings.preprocessorPath.ifBlank { null }
304215 )
305216
306- Log .v(TAG , " Starting transcribe with batchSize= $batchSize , timeSteps= $timeSteps , featureDim= $featureDim " )
217+ Log .v(TAG , " Starting transcribe for: $wavFilePath " )
307218 runOnUiThread {
308219 statusText = " Transcribing..."
309220 }
310- whisperModule.transcribe(audioData, batchSize, timeSteps, featureDim , this @MainActivity)
221+ whisperModule.transcribe(wavFilePath , this @MainActivity)
311222 Log .v(TAG , " Finished transcribe" )
312223
313224 // Display result in Text view instead of Toast
0 commit comments