Skip to content

Commit 96fa34c

Browse files
committed
Merge branch 'main' into multi-model
2 parents c91653f + 61b9431 commit 96fa34c

4 files changed

Lines changed: 47 additions & 122 deletions

File tree

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,22 @@
11
# Whisper Demo App
22

3-
This app runs the Whisper model in ExecuTorch.
3+
This app demonstrates running the Whisper speech recognition model on Android using ExecuTorch.
44

5-
## Build the ExecuTorch Android library
5+
> **Note:** The ExecuTorch `AsrModule` API is not yet released. We will give a snapshot AAR soon™
66
7-
Build the [ExecuTorch Android library with QNN backend](https://github.com/pytorch/executorch/blob/main/examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md).
7+
## Export Model Files
88

9-
## Export the audio processing and model .pte files
9+
Export the audio preprocessor and model `.pte` files following the instructions at:
10+
https://github.com/pytorch/executorch/tree/main/examples/models/whisper
1011

11-
There are two steps, audio processing and the Whisper model (encoder+decoder), which are both done via ExecuTorch.
12+
This app requires both a model `.pte` and a preprocessor `.pte` file.
1213

13-
1. Run `extension/audio/mel_spectrogram.py` to export `whisper_preprocess.pte`
14-
2. Run `examples/qualcomm/oss_scripts/whisper/whisper.py` to export `whisper_qnn_16a8w.pte`
15-
16-
Move these two `.pte` files along with `tokenizer.json` to `/data/local/tmp/whisper` on device.
17-
18-
## Run the app
14+
## Run the App
1915

2016
1. Open WhisperApp in Android Studio
21-
2. Copy the Android library `executorch.aar` (with audio JNI bindings) into `app/libs`
17+
2. Copy the `executorch.aar` library (with audio JNI bindings) into `app/libs`
2218
3. Build and run on device
2319

2420
## Demo
2521

26-
https://github.com/user-attachments/assets/ff8c71c5-b734-4ed4-8382-70a429830665
22+
https://github.com/user-attachments/assets/eb4c4ae6-b89f-4eb4-a291-549a42c95f54

whisper/android/WhisperApp/app/src/main/java/com/example/whisperapp/MainActivity.kt

Lines changed: 29 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -44,30 +44,25 @@ import androidx.compose.ui.unit.dp
4444
import androidx.core.content.ContextCompat
4545
import androidx.lifecycle.ViewModelProvider
4646
import com.example.whisperapp.ui.theme.WhisperAppTheme
47-
import org.pytorch.executorch.EValue
48-
import org.pytorch.executorch.Module
49-
import org.pytorch.executorch.Tensor
5047
import org.pytorch.executorch.extension.asr.AsrCallback
5148
import org.pytorch.executorch.extension.asr.AsrModule
5249
import java.io.File
53-
import java.io.FileInputStream
5450
import java.io.FileOutputStream
5551
import java.io.IOException
5652
import java.io.OutputStream
57-
import java.nio.ByteBuffer
58-
import java.nio.ByteOrder
5953

6054
class MainActivity : ComponentActivity(), AsrCallback {
6155

6256
companion object {
6357
private const val TAG = "MainActivity"
64-
private const val RECORDING_DURATION_MS = 5000L // 5 seconds
58+
private const val RECORDING_DURATION_MS = 30000L // 30 seconds
6559
// Token lengths to remove from transcription output
6660
private const val START_TOKEN_LENGTH = 37
6761
private const val END_TOKEN_LENGTH = 13
6862
}
6963

7064
private var transcriptionOutput by mutableStateOf("")
65+
private var rawTranscriptionOutput = ""
7166
private var buttonText by mutableStateOf("Record")
7267
private var buttonEnabled by mutableStateOf(true)
7368
private var statusText by mutableStateOf("")
@@ -97,48 +92,6 @@ class MainActivity : ComponentActivity(), AsrCallback {
9792
SETTINGS
9893
}
9994

100-
@Throws(IOException::class)
101-
fun readWavPcmBytes(filePath: String): ByteArray {
102-
val wavHeaderSize = 44 // Standard header size for PCM WAV
103-
val file = File(filePath)
104-
val fis = FileInputStream(file)
105-
try {
106-
val totalSize = file.length()
107-
assert(totalSize > wavHeaderSize)
108-
val pcmSize = (totalSize - wavHeaderSize).toInt()
109-
val pcmBytes = ByteArray(pcmSize)
110-
// Skip the header
111-
val skipped = fis.skip(wavHeaderSize.toLong())
112-
if (skipped != wavHeaderSize.toLong()) throw IOException("Failed to skip WAV header")
113-
// Read PCM data
114-
val read = fis.read(pcmBytes)
115-
if (read != pcmSize) throw IOException("Failed to read all PCM data")
116-
return pcmBytes
117-
} finally {
118-
fis.close()
119-
}
120-
}
121-
122-
private fun convertPcm16ToFloat(audioBytes: ByteArray): FloatArray {
123-
val totalSamples = audioBytes.size / 2 // 2 bytes per 16-bit sample
124-
val floatSamples = FloatArray(totalSamples)
125-
126-
// Create ByteBuffer with little-endian byte order (standard for WAV)
127-
val byteBuffer = ByteBuffer.wrap(audioBytes).order(ByteOrder.LITTLE_ENDIAN)
128-
129-
for (i in 0 until totalSamples) {
130-
val sample = byteBuffer.short.toInt()
131-
// Normalize 16-bit PCM to [-1.0, 1.0]
132-
floatSamples[i] = if (sample < 0) {
133-
sample / 32768.0f
134-
} else {
135-
sample / 32767.0f
136-
}
137-
}
138-
139-
return floatSamples
140-
}
141-
14295
override fun onCreate(savedInstanceState: Bundle?) {
14396
super.onCreate(savedInstanceState)
14497

@@ -250,87 +203,57 @@ class MainActivity : ComponentActivity(), AsrCallback {
250203
return
251204
}
252205

206+
rawTranscriptionOutput = ""
253207
runOnUiThread {
254208
transcriptionOutput = ""
255-
}
256-
257-
val audioData: FloatArray
258-
val batchSize: Int
259-
val featureDim: Int
260-
val timeSteps: Int
261-
262-
if (settings.hasPreprocessor()) {
263-
// Use preprocessor to convert WAV to mel-spectrogram
264-
Log.v(TAG, "Using preprocessor: ${settings.preprocessorPath}")
265-
runOnUiThread {
266-
statusText = "Processing audio with mel-spectrogram..."
267-
}
268-
269-
val pcmBytes = readWavPcmBytes(wavFilePath)
270-
val inputFloatArray = convertPcm16ToFloat(pcmBytes)
271-
272-
val tensor1 = Tensor.fromBlob(
273-
inputFloatArray,
274-
longArrayOf(inputFloatArray.size.toLong())
275-
)
276-
val module = Module.load(settings.preprocessorPath)
277-
val eValue1 = EValue.from(tensor1)
278-
audioData = module.forward(eValue1)[0].toTensor().dataAsFloatArray
279-
280-
// result shape is [batchSize, timeSteps, featureDim]
281-
batchSize = 1
282-
featureDim = 128 // Whisper uses 128 mel bins
283-
timeSteps = audioData.size / (batchSize * featureDim)
284-
} else {
285-
// No preprocessor: use raw WAV audio directly
286-
Log.v(TAG, "No preprocessor, using raw WAV audio")
287-
runOnUiThread {
288-
statusText = "Processing raw audio..."
289-
}
290-
291-
val pcmBytes = readWavPcmBytes(wavFilePath)
292-
audioData = convertPcm16ToFloat(pcmBytes)
293-
294-
// For raw audio: batchSize=1, timeSteps=numSamples, featureDim=1
295-
batchSize = 1
296-
featureDim = 1 // Raw audio has 1 feature dimension
297-
timeSteps = audioData.size
209+
statusText = "Loading model..."
210+
buttonText = "Transcribing..."
211+
buttonEnabled = false
298212
}
299213

300214
val whisperModule = AsrModule(
301-
settings.modelPath,
302-
settings.tokenizerPath,
303-
settings.dataPath
215+
modelPath = settings.modelPath,
216+
tokenizerPath = settings.tokenizerPath,
217+
dataPath = settings.dataPath.ifBlank { null },
218+
preprocessorPath = settings.preprocessorPath.ifBlank { null }
304219
)
305220

306-
Log.v(TAG, "Starting transcribe with batchSize=$batchSize, timeSteps=$timeSteps, featureDim=$featureDim")
221+
Log.v(TAG, "Starting transcribe for: $wavFilePath")
307222
runOnUiThread {
308223
statusText = "Transcribing..."
309224
}
310-
whisperModule.transcribe(audioData, batchSize, timeSteps, featureDim, this@MainActivity)
311-
Log.v(TAG, "Finished transcribe")
225+
val startTime = System.currentTimeMillis()
226+
whisperModule.transcribe(wavFilePath, callback = this@MainActivity)
227+
val elapsedTime = System.currentTimeMillis() - startTime
228+
val elapsedSeconds = elapsedTime / 1000.0
229+
Log.v(TAG, "Finished transcribe in ${elapsedSeconds}s")
312230

313231
// Display result in Text view instead of Toast
314232
// hack to remove start and end tokens; ideally the runner should not do callback on these tokens
315233
runOnUiThread {
316234
val minLength = START_TOKEN_LENGTH + END_TOKEN_LENGTH
317-
if (transcriptionOutput.length > minLength) {
318-
val endIndex = transcriptionOutput.length - END_TOKEN_LENGTH
235+
if (rawTranscriptionOutput.length > minLength) {
236+
val endIndex = rawTranscriptionOutput.length - END_TOKEN_LENGTH
319237
if (endIndex > START_TOKEN_LENGTH) {
320-
transcriptionOutput = transcriptionOutput.substring(START_TOKEN_LENGTH, endIndex)
238+
transcriptionOutput = rawTranscriptionOutput.substring(START_TOKEN_LENGTH, endIndex)
321239
}
322240
}
323-
statusText = "Transcription complete"
241+
statusText = "Transcription complete (%.2fs)".format(elapsedSeconds)
242+
buttonText = "Record"
324243
buttonEnabled = true
325244
}
326245
}
327246

328247
override fun onToken(result: String) {
329248
Log.v(TAG, "Called callback: here's the current output")
249+
rawTranscriptionOutput += result
330250
runOnUiThread {
331-
transcriptionOutput += result
251+
// Strip start token prefix for display while transcribing
252+
if (rawTranscriptionOutput.length > START_TOKEN_LENGTH) {
253+
transcriptionOutput = rawTranscriptionOutput.substring(START_TOKEN_LENGTH)
254+
}
332255
}
333-
Log.v(TAG, transcriptionOutput)
256+
Log.v(TAG, rawTranscriptionOutput)
334257
}
335258

336259
private fun startRecording() {
@@ -358,10 +281,10 @@ class MainActivity : ComponentActivity(), AsrCallback {
358281
audioRecord?.startRecording()
359282
isRecording = true
360283

361-
buttonText = "Recording... (5s)"
284+
buttonText = "Recording... (30s)"
362285
buttonEnabled = false
363286

364-
// Schedule automatic stop after 5 seconds
287+
// Schedule automatic stop after 30 seconds
365288
stopRecordingRunnable = Runnable {
366289
stopRecording()
367290
}

whisper/android/WhisperApp/app/src/main/java/com/example/whisperapp/ModelSettingsScreen.kt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package com.example.whisperapp
22

3+
import androidx.compose.foundation.clickable
34
import androidx.compose.foundation.layout.*
45
import androidx.compose.foundation.rememberScrollState
56
import androidx.compose.foundation.verticalScroll
@@ -286,7 +287,9 @@ fun FileSelectionDialog(
286287
Column(modifier = Modifier.verticalScroll(rememberScrollState())) {
287288
if (allowNone) {
288289
Row(
289-
modifier = Modifier.fillMaxWidth(),
290+
modifier = Modifier
291+
.fillMaxWidth()
292+
.clickable { onSelect("") },
290293
verticalAlignment = Alignment.CenterVertically
291294
) {
292295
RadioButton(
@@ -301,7 +304,9 @@ fun FileSelectionDialog(
301304

302305
files.forEach { filePath ->
303306
Row(
304-
modifier = Modifier.fillMaxWidth(),
307+
modifier = Modifier
308+
.fillMaxWidth()
309+
.clickable { onSelect(filePath) },
305310
verticalAlignment = Alignment.CenterVertically
306311
) {
307312
RadioButton(

whisper/android/WhisperApp/app/src/test/java/com/example/whisperapp/ModelSettingsViewModelTest.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.example.whisperapp
22

3-
import org.junit.Assert.*
3+
import org.junit.Assert.assertEquals
4+
import org.junit.Assert.assertTrue
45
import org.junit.Before
56
import org.junit.Rule
67
import org.junit.Test

0 commit comments

Comments
 (0)