@@ -44,30 +44,25 @@ import androidx.compose.ui.unit.dp
4444import androidx.core.content.ContextCompat
4545import androidx.lifecycle.ViewModelProvider
4646import com.example.whisperapp.ui.theme.WhisperAppTheme
47- import org.pytorch.executorch.EValue
48- import org.pytorch.executorch.Module
49- import org.pytorch.executorch.Tensor
5047import org.pytorch.executorch.extension.asr.AsrCallback
5148import org.pytorch.executorch.extension.asr.AsrModule
5249import java.io.File
53- import java.io.FileInputStream
5450import java.io.FileOutputStream
5551import java.io.IOException
5652import java.io.OutputStream
57- import java.nio.ByteBuffer
58- import java.nio.ByteOrder
5953
6054class MainActivity : ComponentActivity (), AsrCallback {
6155
6256 companion object {
6357 private const val TAG = " MainActivity"
64- private const val RECORDING_DURATION_MS = 5000L // 5 seconds
58+ private const val RECORDING_DURATION_MS = 30000L // 30 seconds
6559 // Token lengths to remove from transcription output
6660 private const val START_TOKEN_LENGTH = 37
6761 private const val END_TOKEN_LENGTH = 13
6862 }
6963
7064 private var transcriptionOutput by mutableStateOf(" " )
65+ private var rawTranscriptionOutput = " "
7166 private var buttonText by mutableStateOf(" Record" )
7267 private var buttonEnabled by mutableStateOf(true )
7368 private var statusText by mutableStateOf(" " )
@@ -97,48 +92,6 @@ class MainActivity : ComponentActivity(), AsrCallback {
9792 SETTINGS
9893 }
9994
100- @Throws(IOException ::class )
101- fun readWavPcmBytes (filePath : String ): ByteArray {
102- val wavHeaderSize = 44 // Standard header size for PCM WAV
103- val file = File (filePath)
104- val fis = FileInputStream (file)
105- try {
106- val totalSize = file.length()
107- assert (totalSize > wavHeaderSize)
108- val pcmSize = (totalSize - wavHeaderSize).toInt()
109- val pcmBytes = ByteArray (pcmSize)
110- // Skip the header
111- val skipped = fis.skip(wavHeaderSize.toLong())
112- if (skipped != wavHeaderSize.toLong()) throw IOException (" Failed to skip WAV header" )
113- // Read PCM data
114- val read = fis.read(pcmBytes)
115- if (read != pcmSize) throw IOException (" Failed to read all PCM data" )
116- return pcmBytes
117- } finally {
118- fis.close()
119- }
120- }
121-
122- private fun convertPcm16ToFloat (audioBytes : ByteArray ): FloatArray {
123- val totalSamples = audioBytes.size / 2 // 2 bytes per 16-bit sample
124- val floatSamples = FloatArray (totalSamples)
125-
126- // Create ByteBuffer with little-endian byte order (standard for WAV)
127- val byteBuffer = ByteBuffer .wrap(audioBytes).order(ByteOrder .LITTLE_ENDIAN )
128-
129- for (i in 0 until totalSamples) {
130- val sample = byteBuffer.short.toInt()
131- // Normalize 16-bit PCM to [-1.0, 1.0]
132- floatSamples[i] = if (sample < 0 ) {
133- sample / 32768.0f
134- } else {
135- sample / 32767.0f
136- }
137- }
138-
139- return floatSamples
140- }
141-
14295 override fun onCreate (savedInstanceState : Bundle ? ) {
14396 super .onCreate(savedInstanceState)
14497
@@ -250,87 +203,57 @@ class MainActivity : ComponentActivity(), AsrCallback {
250203 return
251204 }
252205
206+ rawTranscriptionOutput = " "
253207 runOnUiThread {
254208 transcriptionOutput = " "
255- }
256-
257- val audioData: FloatArray
258- val batchSize: Int
259- val featureDim: Int
260- val timeSteps: Int
261-
262- if (settings.hasPreprocessor()) {
263- // Use preprocessor to convert WAV to mel-spectrogram
264- Log .v(TAG , " Using preprocessor: ${settings.preprocessorPath} " )
265- runOnUiThread {
266- statusText = " Processing audio with mel-spectrogram..."
267- }
268-
269- val pcmBytes = readWavPcmBytes(wavFilePath)
270- val inputFloatArray = convertPcm16ToFloat(pcmBytes)
271-
272- val tensor1 = Tensor .fromBlob(
273- inputFloatArray,
274- longArrayOf(inputFloatArray.size.toLong())
275- )
276- val module = Module .load(settings.preprocessorPath)
277- val eValue1 = EValue .from(tensor1)
278- audioData = module.forward(eValue1)[0 ].toTensor().dataAsFloatArray
279-
280- // result shape is [batchSize, timeSteps, featureDim]
281- batchSize = 1
282- featureDim = 128 // Whisper uses 128 mel bins
283- timeSteps = audioData.size / (batchSize * featureDim)
284- } else {
285- // No preprocessor: use raw WAV audio directly
286- Log .v(TAG , " No preprocessor, using raw WAV audio" )
287- runOnUiThread {
288- statusText = " Processing raw audio..."
289- }
290-
291- val pcmBytes = readWavPcmBytes(wavFilePath)
292- audioData = convertPcm16ToFloat(pcmBytes)
293-
294- // For raw audio: batchSize=1, timeSteps=numSamples, featureDim=1
295- batchSize = 1
296- featureDim = 1 // Raw audio has 1 feature dimension
297- timeSteps = audioData.size
209+ statusText = " Loading model..."
210+ buttonText = " Transcribing..."
211+ buttonEnabled = false
298212 }
299213
300214 val whisperModule = AsrModule (
301- settings.modelPath,
302- settings.tokenizerPath,
303- settings.dataPath
215+ modelPath = settings.modelPath,
216+ tokenizerPath = settings.tokenizerPath,
217+ dataPath = settings.dataPath.ifBlank { null },
218+ preprocessorPath = settings.preprocessorPath.ifBlank { null }
304219 )
305220
306- Log .v(TAG , " Starting transcribe with batchSize= $batchSize , timeSteps= $timeSteps , featureDim= $featureDim " )
221+ Log .v(TAG , " Starting transcribe for: $wavFilePath " )
307222 runOnUiThread {
308223 statusText = " Transcribing..."
309224 }
310- whisperModule.transcribe(audioData, batchSize, timeSteps, featureDim, this @MainActivity)
311- Log .v(TAG , " Finished transcribe" )
225+ val startTime = System .currentTimeMillis()
226+ whisperModule.transcribe(wavFilePath, callback = this @MainActivity)
227+ val elapsedTime = System .currentTimeMillis() - startTime
228+ val elapsedSeconds = elapsedTime / 1000.0
229+ Log .v(TAG , " Finished transcribe in ${elapsedSeconds} s" )
312230
313231 // Display result in Text view instead of Toast
314232 // hack to remove start and end tokens; ideally the runner should not do callback on these tokens
315233 runOnUiThread {
316234 val minLength = START_TOKEN_LENGTH + END_TOKEN_LENGTH
317- if (transcriptionOutput .length > minLength) {
318- val endIndex = transcriptionOutput .length - END_TOKEN_LENGTH
235+ if (rawTranscriptionOutput .length > minLength) {
236+ val endIndex = rawTranscriptionOutput .length - END_TOKEN_LENGTH
319237 if (endIndex > START_TOKEN_LENGTH ) {
320- transcriptionOutput = transcriptionOutput .substring(START_TOKEN_LENGTH , endIndex)
238+ transcriptionOutput = rawTranscriptionOutput .substring(START_TOKEN_LENGTH , endIndex)
321239 }
322240 }
323- statusText = " Transcription complete"
241+ statusText = " Transcription complete (%.2fs)" .format(elapsedSeconds)
242+ buttonText = " Record"
324243 buttonEnabled = true
325244 }
326245 }
327246
328247 override fun onToken (result : String ) {
329248 Log .v(TAG , " Called callback: here's the current output" )
249+ rawTranscriptionOutput + = result
330250 runOnUiThread {
331- transcriptionOutput + = result
251+ // Strip start token prefix for display while transcribing
252+ if (rawTranscriptionOutput.length > START_TOKEN_LENGTH ) {
253+ transcriptionOutput = rawTranscriptionOutput.substring(START_TOKEN_LENGTH )
254+ }
332255 }
333- Log .v(TAG , transcriptionOutput )
256+ Log .v(TAG , rawTranscriptionOutput )
334257 }
335258
336259 private fun startRecording () {
@@ -358,10 +281,10 @@ class MainActivity : ComponentActivity(), AsrCallback {
358281 audioRecord?.startRecording()
359282 isRecording = true
360283
361- buttonText = " Recording... (5s )"
284+ buttonText = " Recording... (30s )"
362285 buttonEnabled = false
363286
364- // Schedule automatic stop after 5 seconds
287+ // Schedule automatic stop after 30 seconds
365288 stopRecordingRunnable = Runnable {
366289 stopRecording()
367290 }
0 commit comments