Skip to content

Commit d277202

Browse files
authored
Java ASR Module binding (pytorch#16979)
1 parent 40d94b6 commit d277202

6 files changed

Lines changed: 738 additions & 2 deletions

File tree

extension/android/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,11 @@ if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
170170
endif()
171171

172172
if(EXECUTORCH_BUILD_LLAMA_JNI)
173-
target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/log.cpp)
174-
list(APPEND link_libraries extension_llm_runner)
173+
target_sources(
174+
executorch_jni PRIVATE jni/jni_layer_llama.cpp jni/jni_layer_asr.cpp
175+
jni/log.cpp
176+
)
177+
list(APPEND link_libraries extension_llm_runner extension_asr_runner)
175178
target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_LLAMA_JNI=1)
176179

177180
if(QNN_SDK_ROOT)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch.extension.asr
10+
11+
import org.pytorch.executorch.annotations.Experimental
12+
13+
/**
14+
* Callback interface for ASR (Automatic Speech Recognition) module. Users can implement this
15+
* interface to receive the transcribed tokens and completion notification.
16+
*
17+
* Warning: These APIs are experimental and subject to change without notice
18+
*/
19+
@Experimental
20+
interface AsrCallback {
21+
/**
22+
* Called when a new token is available from JNI. Users will keep getting onToken() invocations
23+
* until transcription finishes.
24+
*
25+
* @param token The decoded text token
26+
*/
27+
fun onToken(token: String)
28+
29+
/**
30+
* Called when transcription is complete.
31+
*
32+
* @param transcription The complete transcription (may be empty if tokens were streamed)
33+
*/
34+
fun onComplete(transcription: String) {}
35+
}
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch.extension.asr
10+
11+
import java.io.Closeable
12+
import java.io.File
13+
import java.util.concurrent.atomic.AtomicLong
14+
import org.pytorch.executorch.annotations.Experimental
15+
16+
/**
17+
* AsrModule is a wrapper around the ExecuTorch ASR Runner. It provides a simple interface to
18+
* transcribe audio from WAV files using speech recognition models like Whisper.
19+
*
20+
* The module loads a WAV file, optionally preprocesses it using a preprocessor module (e.g., for
21+
* mel-spectrogram extraction), and then runs the ASR model to generate transcriptions.
22+
*
23+
* Warning: These APIs are experimental and subject to change without notice
24+
*
25+
* @param modelPath Path to the ExecuTorch model file (.pte). The model must expose exactly two
26+
* callable methods named "encoder" and "text_decoder" (these names are required).
27+
* @param tokenizerPath Path to the tokenizer directory containing tokenizer.json
28+
* @param dataPath Optional path to additional data file (e.g., for delegate data)
29+
* @param preprocessorPath Optional path to preprocessor .pte for converting raw audio to features.
30+
* If not provided, raw audio samples will be passed directly to the model.
31+
*/
32+
@Experimental
33+
class AsrModule(
34+
modelPath: String,
35+
tokenizerPath: String,
36+
dataPath: String? = null,
37+
preprocessorPath: String? = null,
38+
) : Closeable {
39+
40+
private val nativeHandle = AtomicLong(0L)
41+
42+
init {
43+
val modelFile = File(modelPath)
44+
require(modelFile.canRead() && modelFile.isFile) { "Cannot load model path $modelPath" }
45+
val tokenizerFile = File(tokenizerPath)
46+
require(tokenizerFile.exists()) { "Cannot load tokenizer path $tokenizerPath" }
47+
if (preprocessorPath != null) {
48+
val preprocessorFile = File(preprocessorPath)
49+
require(preprocessorFile.canRead() && preprocessorFile.isFile) {
50+
"Cannot load preprocessor path $preprocessorPath"
51+
}
52+
}
53+
54+
val handle = nativeCreate(modelPath, tokenizerPath, dataPath, preprocessorPath)
55+
if (handle == 0L) {
56+
throw RuntimeException("Failed to create native AsrModule")
57+
}
58+
nativeHandle.set(handle)
59+
}
60+
61+
companion object {
62+
init {
63+
System.loadLibrary("executorch")
64+
}
65+
66+
@JvmStatic
67+
private external fun nativeCreate(
68+
modelPath: String,
69+
tokenizerPath: String,
70+
dataPath: String?,
71+
preprocessorPath: String?,
72+
): Long
73+
74+
@JvmStatic private external fun nativeDestroy(nativeHandle: Long)
75+
76+
@JvmStatic private external fun nativeLoad(nativeHandle: Long): Int
77+
78+
@JvmStatic private external fun nativeIsLoaded(nativeHandle: Long): Boolean
79+
80+
@JvmStatic
81+
private external fun nativeTranscribe(
82+
nativeHandle: Long,
83+
wavPath: String,
84+
maxNewTokens: Long,
85+
temperature: Float,
86+
decoderStartTokenId: Long,
87+
callback: AsrCallback?,
88+
): Int
89+
}
90+
91+
/** Check if the native handle is valid. */
92+
val isValid: Boolean
93+
get() = nativeHandle.get() != 0L
94+
95+
/** Check if the module is loaded and ready for inference. */
96+
val isLoaded: Boolean
97+
get() {
98+
val handle = nativeHandle.get()
99+
return handle != 0L && nativeIsLoaded(handle)
100+
}
101+
102+
/** Releases native resources. Call this when done with the module. */
103+
fun destroy() {
104+
val handle = nativeHandle.getAndSet(0L)
105+
if (handle != 0L) {
106+
nativeDestroy(handle)
107+
}
108+
}
109+
110+
/** Closeable implementation for use with use {} blocks. */
111+
override fun close() {
112+
destroy()
113+
}
114+
115+
/**
116+
* Force loading the module. Otherwise the model is loaded during first transcribe() call.
117+
*
118+
* @return 0 on success, error code otherwise
119+
* @throws IllegalStateException if the module has been destroyed
120+
*/
121+
fun load(): Int {
122+
val handle = nativeHandle.get()
123+
check(handle != 0L) { "AsrModule has been destroyed" }
124+
return nativeLoad(handle)
125+
}
126+
127+
/**
128+
* Transcribe audio from a WAV file with default configuration.
129+
*
130+
* @param wavPath Path to the WAV audio file
131+
* @param callback Callback to receive tokens, can be null
132+
* @return 0 on success, error code otherwise
133+
* @throws IllegalStateException if the module has been destroyed
134+
*/
135+
fun transcribe(wavPath: String, callback: AsrCallback? = null): Int =
136+
transcribe(wavPath, AsrTranscribeConfig(), callback)
137+
138+
/**
139+
* Transcribe audio from a WAV file with custom configuration.
140+
*
141+
* @param wavPath Path to the WAV audio file
142+
* @param config Configuration for transcription
143+
* @param callback Callback to receive tokens, can be null
144+
* @return 0 on success, error code otherwise
145+
* @throws IllegalStateException if the module has been destroyed
146+
*/
147+
fun transcribe(
148+
wavPath: String,
149+
config: AsrTranscribeConfig,
150+
callback: AsrCallback? = null,
151+
): Int {
152+
val handle = nativeHandle.get()
153+
check(handle != 0L) { "AsrModule has been destroyed" }
154+
val wavFile = File(wavPath)
155+
require(wavFile.canRead() && wavFile.isFile) { "Cannot read WAV file: $wavPath" }
156+
return nativeTranscribe(
157+
handle,
158+
wavPath,
159+
config.maxNewTokens,
160+
config.temperature,
161+
config.decoderStartTokenId,
162+
callback,
163+
)
164+
}
165+
166+
/**
167+
* Transcribe audio from a WAV file and return the full transcription.
168+
*
169+
* This is a blocking call that collects all tokens and returns the complete transcription.
170+
*
171+
* @param wavPath Path to the WAV audio file
172+
* @param config Configuration for transcription
173+
* @return The transcribed text
174+
* @throws RuntimeException if transcription fails
175+
*/
176+
@JvmOverloads
177+
fun transcribeBlocking(
178+
wavPath: String,
179+
config: AsrTranscribeConfig = AsrTranscribeConfig(),
180+
): String {
181+
val result = StringBuilder()
182+
val status =
183+
transcribe(
184+
wavPath,
185+
config,
186+
object : AsrCallback {
187+
override fun onToken(token: String) {
188+
result.append(token)
189+
}
190+
191+
override fun onComplete(transcription: String) {
192+
// Tokens already collected
193+
}
194+
},
195+
)
196+
197+
if (status != 0) {
198+
throw RuntimeException("Transcription failed with error code: $status")
199+
}
200+
201+
return result.toString()
202+
}
203+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package org.pytorch.executorch.extension.asr
10+
11+
import org.pytorch.executorch.annotations.Experimental
12+
13+
/**
14+
* Configuration for ASR transcription.
15+
*
16+
* Warning: These APIs are experimental and subject to change without notice
17+
*
18+
* @property maxNewTokens Maximum number of new tokens to generate (must be positive)
19+
* @property temperature Temperature for sampling. 0.0 means greedy decoding
20+
* @property decoderStartTokenId The token ID to start decoding with (e.g., language token for
21+
* Whisper)
22+
*/
23+
@Experimental
24+
data class AsrTranscribeConfig(
25+
val maxNewTokens: Long = 128,
26+
val temperature: Float = 0.0f,
27+
val decoderStartTokenId: Long = 0,
28+
) {
29+
init {
30+
require(maxNewTokens > 0) { "maxNewTokens must be positive" }
31+
require(temperature >= 0) { "temperature must be non-negative" }
32+
}
33+
34+
/** Builder class for AsrTranscribeConfig for Java interoperability. */
35+
class Builder {
36+
private var maxNewTokens: Long = 128
37+
private var temperature: Float = 0.0f
38+
private var decoderStartTokenId: Long = 0
39+
40+
fun setMaxNewTokens(maxNewTokens: Long) = apply {
41+
require(maxNewTokens > 0) { "maxNewTokens must be positive" }
42+
this.maxNewTokens = maxNewTokens
43+
}
44+
45+
fun setTemperature(temperature: Float) = apply {
46+
require(temperature >= 0) { "temperature must be non-negative" }
47+
this.temperature = temperature
48+
}
49+
50+
fun setDecoderStartTokenId(decoderStartTokenId: Long) = apply {
51+
this.decoderStartTokenId = decoderStartTokenId
52+
}
53+
54+
fun build() =
55+
AsrTranscribeConfig(
56+
maxNewTokens = maxNewTokens,
57+
temperature = temperature,
58+
decoderStartTokenId = decoderStartTokenId,
59+
)
60+
}
61+
}

0 commit comments

Comments
 (0)