Skip to content

Commit feed004

Browse files
authored
AsrModule demo (#194)
1 parent 9a8b021 commit feed004

2 files changed

Lines changed: 16 additions & 109 deletions

File tree

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,22 @@
11
# Whisper Demo App
22

3-
This app runs the Whisper model in ExecuTorch.
3+
This app demonstrates running the Whisper speech recognition model on Android using ExecuTorch.
44

5-
## Build the ExecuTorch Android library
5+
> **Note:** The ExecuTorch `AsrModule` API is not yet released. A snapshot AAR will be provided soon.
66
7-
Build the [ExecuTorch Android library with QNN backend](https://github.com/pytorch/executorch/blob/main/examples/demo-apps/android/LlamaDemo/docs/delegates/qualcomm_README.md).
7+
## Export Model Files
88

9-
## Export the audio processing and model .pte files
9+
Export the audio preprocessor and model `.pte` files following the instructions at:
10+
https://github.com/pytorch/executorch/tree/main/examples/models/whisper
1011

11-
There are two steps, audio processing and the Whisper model (encoder+decoder), which are both done via ExecuTorch.
12+
This app requires both a model `.pte` and a preprocessor `.pte` file.
1213

13-
1. Run `extension/audio/mel_spectrogram.py` to export `whisper_preprocess.pte`
14-
2. Run `examples/qualcomm/oss_scripts/whisper/whisper.py` to export `whisper_qnn_16a8w.pte`
15-
16-
Move these two `.pte` files along with `tokenizer.json` to `/data/local/tmp/whisper` on device.
17-
18-
## Run the app
14+
## Run the App
1915

2016
1. Open WhisperApp in Android Studio
21-
2. Copy the Android library `executorch.aar` (with audio JNI bindings) into `app/libs`
17+
2. Copy the `executorch.aar` library (with audio JNI bindings) into `app/libs`
2218
3. Build and run on device
2319

2420
## Demo
2521

26-
https://github.com/user-attachments/assets/ff8c71c5-b734-4ed4-8382-70a429830665
22+
https://github.com/user-attachments/assets/eb4c4ae6-b89f-4eb4-a291-549a42c95f54

whisper/android/WhisperApp/app/src/main/java/com/example/whisperapp/MainActivity.kt

Lines changed: 7 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,12 @@ import androidx.compose.ui.unit.dp
4444
import androidx.core.content.ContextCompat
4545
import androidx.lifecycle.ViewModelProvider
4646
import com.example.whisperapp.ui.theme.WhisperAppTheme
47-
import org.pytorch.executorch.EValue
48-
import org.pytorch.executorch.Module
49-
import org.pytorch.executorch.Tensor
5047
import org.pytorch.executorch.extension.asr.AsrCallback
5148
import org.pytorch.executorch.extension.asr.AsrModule
5249
import java.io.File
53-
import java.io.FileInputStream
5450
import java.io.FileOutputStream
5551
import java.io.IOException
5652
import java.io.OutputStream
57-
import java.nio.ByteBuffer
58-
import java.nio.ByteOrder
5953

6054
class MainActivity : ComponentActivity(), AsrCallback {
6155

@@ -97,48 +91,6 @@ class MainActivity : ComponentActivity(), AsrCallback {
9791
SETTINGS
9892
}
9993

100-
@Throws(IOException::class)
101-
fun readWavPcmBytes(filePath: String): ByteArray {
102-
val wavHeaderSize = 44 // Standard header size for PCM WAV
103-
val file = File(filePath)
104-
val fis = FileInputStream(file)
105-
try {
106-
val totalSize = file.length()
107-
assert(totalSize > wavHeaderSize)
108-
val pcmSize = (totalSize - wavHeaderSize).toInt()
109-
val pcmBytes = ByteArray(pcmSize)
110-
// Skip the header
111-
val skipped = fis.skip(wavHeaderSize.toLong())
112-
if (skipped != wavHeaderSize.toLong()) throw IOException("Failed to skip WAV header")
113-
// Read PCM data
114-
val read = fis.read(pcmBytes)
115-
if (read != pcmSize) throw IOException("Failed to read all PCM data")
116-
return pcmBytes
117-
} finally {
118-
fis.close()
119-
}
120-
}
121-
122-
private fun convertPcm16ToFloat(audioBytes: ByteArray): FloatArray {
123-
val totalSamples = audioBytes.size / 2 // 2 bytes per 16-bit sample
124-
val floatSamples = FloatArray(totalSamples)
125-
126-
// Create ByteBuffer with little-endian byte order (standard for WAV)
127-
val byteBuffer = ByteBuffer.wrap(audioBytes).order(ByteOrder.LITTLE_ENDIAN)
128-
129-
for (i in 0 until totalSamples) {
130-
val sample = byteBuffer.short.toInt()
131-
// Normalize 16-bit PCM to [-1.0, 1.0]
132-
floatSamples[i] = if (sample < 0) {
133-
sample / 32768.0f
134-
} else {
135-
sample / 32767.0f
136-
}
137-
}
138-
139-
return floatSamples
140-
}
141-
14294
override fun onCreate(savedInstanceState: Bundle?) {
14395
super.onCreate(savedInstanceState)
14496

@@ -252,62 +204,21 @@ class MainActivity : ComponentActivity(), AsrCallback {
252204

253205
runOnUiThread {
254206
transcriptionOutput = ""
255-
}
256-
257-
val audioData: FloatArray
258-
val batchSize: Int
259-
val featureDim: Int
260-
val timeSteps: Int
261-
262-
if (settings.hasPreprocessor()) {
263-
// Use preprocessor to convert WAV to mel-spectrogram
264-
Log.v(TAG, "Using preprocessor: ${settings.preprocessorPath}")
265-
runOnUiThread {
266-
statusText = "Processing audio with mel-spectrogram..."
267-
}
268-
269-
val pcmBytes = readWavPcmBytes(wavFilePath)
270-
val inputFloatArray = convertPcm16ToFloat(pcmBytes)
271-
272-
val tensor1 = Tensor.fromBlob(
273-
inputFloatArray,
274-
longArrayOf(inputFloatArray.size.toLong())
275-
)
276-
val module = Module.load(settings.preprocessorPath)
277-
val eValue1 = EValue.from(tensor1)
278-
audioData = module.forward(eValue1)[0].toTensor().dataAsFloatArray
279-
280-
// result shape is [batchSize, timeSteps, featureDim]
281-
batchSize = 1
282-
featureDim = 128 // Whisper uses 128 mel bins
283-
timeSteps = audioData.size / (batchSize * featureDim)
284-
} else {
285-
// No preprocessor: use raw WAV audio directly
286-
Log.v(TAG, "No preprocessor, using raw WAV audio")
287-
runOnUiThread {
288-
statusText = "Processing raw audio..."
289-
}
290-
291-
val pcmBytes = readWavPcmBytes(wavFilePath)
292-
audioData = convertPcm16ToFloat(pcmBytes)
293-
294-
// For raw audio: batchSize=1, timeSteps=numSamples, featureDim=1
295-
batchSize = 1
296-
featureDim = 1 // Raw audio has 1 feature dimension
297-
timeSteps = audioData.size
207+
statusText = "Loading model..."
298208
}
299209

300210
val whisperModule = AsrModule(
301-
settings.modelPath,
302-
settings.tokenizerPath,
303-
settings.dataPath
211+
modelPath = settings.modelPath,
212+
tokenizerPath = settings.tokenizerPath,
213+
dataPath = settings.dataPath.ifBlank { null },
214+
preprocessorPath = settings.preprocessorPath.ifBlank { null }
304215
)
305216

306-
Log.v(TAG, "Starting transcribe with batchSize=$batchSize, timeSteps=$timeSteps, featureDim=$featureDim")
217+
Log.v(TAG, "Starting transcribe for: $wavFilePath")
307218
runOnUiThread {
308219
statusText = "Transcribing..."
309220
}
310-
whisperModule.transcribe(audioData, batchSize, timeSteps, featureDim, this@MainActivity)
221+
whisperModule.transcribe(wavFilePath, this@MainActivity)
311222
Log.v(TAG, "Finished transcribe")
312223

313224
// Display result in Text view instead of Toast

0 commit comments

Comments
 (0)