Skip to content

Commit 9a97485

Browse files
mkopcins and Mateusz Kopciński authored
feat: optimize S2T modules, removed react-native-audio-api dependency (#140)
## Description

Added caching of encodings on the native side to avoid the overhead of sending this data between the native and React sides. Removed `react-native-audio-api` as a dependency and made it a suggested approach in the docs instead. Added benchmarks for S2T modules.

### Type of change

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing documentation)

### Tested on

- [x] iOS
- [x] Android

### Testing instructions

<!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. -->

### Screenshots

<!-- Add screenshots here, if applicable -->

### Related issues

<!-- Link related issues here using #issue-number -->

### Checklist

- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have updated the documentation accordingly
- [x] My changes generate no new warnings

### Additional notes

<!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->

---------

Co-authored-by: Mateusz Kopciński <mateusz.kopcinski@swmansnion.com>
1 parent cf9079d commit 9a97485

47 files changed

Lines changed: 905 additions & 984 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ https://docs.swmansion.com/react-native-executorch
3232
```bash
3333
# Install the package
3434
yarn add react-native-executorch
35-
# Install necessary peer dependency
36-
yarn add react-native-audio-api
3735
cd ios && pod install && cd ..
3836
```
3937

android/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ repositories {
9393
maven { url 'https://jitpack.io' }
9494
}
9595

96-
def kotlin_version = getExtOrDefault("kotlinVersion")
96+
def kotlin_version = rootProject.ext.has("kotlinVersion") ? rootProject.ext.get("kotlinVersion") : project.properties["RnExecutorch_kotlinVersion"]
9797

9898
dependencies {
9999
// For < 0.71, this will be from the local maven repo

android/gradle/wrapper/gradle-wrapper.properties

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
#Tue Mar 11 09:36:23 CET 2025
12
distributionBase=GRADLE_USER_HOME
23
distributionPath=wrapper/dists
3-
distributionUrl=https\://services.gradle.org/distributions/gradle-8.5-bin.zip
4+
distributionUrl=https\://services.gradle.org/distributions/gradle-8.6-bin.zip
45
networkTimeout=10000
56
validateDistributionUrl=true
67
zipStoreBase=GRADLE_USER_HOME

android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/BaseS2TDecoder.kt

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,19 @@ abstract class BaseS2TDecoder(
1111
reactApplicationContext: ReactApplicationContext,
1212
) : BaseModel<ReadableArray, Int>(reactApplicationContext) {
1313
protected abstract var methodName: String
14+
lateinit var encoderOutput: EValue
1415

1516
abstract fun setGeneratedTokens(tokens: ReadableArray)
1617

1718
abstract fun getTokensEValue(): EValue
1819

1920
override fun runModel(input: ReadableArray): Int {
20-
val tokensEValue = getTokensEValue()
21+
var encoderOutput = this.encoderOutput
22+
if (input.size() != 0) {
23+
encoderOutput = this.preprocess(input)
24+
}
2125
return this.module
22-
.execute(methodName, tokensEValue, this.preprocess(input))[0]
26+
.execute(methodName, getTokensEValue(), encoderOutput)[0]
2327
.toTensor()
2428
.dataAsLongArray
2529
.last()
@@ -28,8 +32,7 @@ abstract class BaseS2TDecoder(
2832

2933
abstract fun getInputShape(inputLength: Int): LongArray
3034

31-
fun preprocess(input: ReadableArray): EValue {
32-
val inputArray = input.getArray(0)!!
35+
fun preprocess(inputArray: ReadableArray): EValue {
3336
val preprocessorInputShape = this.getInputShape(inputArray.size())
3437
return EValue.from(Tensor.fromBlob(createFloatArray(inputArray), preprocessorInputShape))
3538
}

android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/BaseS2TModule.kt

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,32 @@
11
package com.swmansion.rnexecutorch.models.speechtotext
22

3+
import com.facebook.react.bridge.Arguments
34
import com.facebook.react.bridge.ReadableArray
45
import com.facebook.react.bridge.WritableArray
56
import com.swmansion.rnexecutorch.models.BaseModel
7+
import org.pytorch.executorch.EValue
68

79
abstract class BaseS2TModule {
8-
lateinit var encoder: BaseModel<ReadableArray, WritableArray>
10+
lateinit var encoder: BaseModel<ReadableArray, Array<EValue>>
911
lateinit var decoder: BaseS2TDecoder
1012
abstract var startToken: Int
1113
abstract var eosToken: Int
1214

13-
fun encode(input: ReadableArray): WritableArray = this.encoder.runModel(input)
15+
fun encode(input: ReadableArray): WritableArray {
16+
val encoderOutput = this.encoder.runModel(input)
17+
this.decoder.encoderOutput = encoderOutput[0]
18+
return this.postprocessEncodings(encoderOutput)
19+
}
20+
21+
private fun postprocessEncodings(output: Array<EValue>): WritableArray {
22+
val outputWritableArray: WritableArray = Arguments.createArray()
23+
output[0].toTensor().dataAsFloatArray.map {
24+
outputWritableArray.pushDouble(
25+
it.toDouble(),
26+
)
27+
}
28+
return outputWritableArray
29+
}
1430

1531
abstract fun decode(
1632
prevTokens: ReadableArray,

android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/MoonshineDecoder.kt

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,9 @@ class MoonshineDecoder(
1010
reactApplicationContext: ReactApplicationContext,
1111
) : BaseS2TDecoder(reactApplicationContext) {
1212
private lateinit var generatedTokens: LongArray
13+
override var methodName: String = "forward_cached"
1314
private var innerDim: Long = 288
1415

15-
override var methodName: String
16-
get() = "forward_cached"
17-
set(value) {}
18-
1916
override fun setGeneratedTokens(tokens: ReadableArray) {
2017
this.generatedTokens = ArrayUtils.createLongArray(tokens)
2118
}
Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,20 @@
11
package com.swmansion.rnexecutorch.models.speechtotext
22

3-
import com.facebook.react.bridge.Arguments
43
import com.facebook.react.bridge.ReactApplicationContext
54
import com.facebook.react.bridge.ReadableArray
6-
import com.facebook.react.bridge.WritableArray
75
import com.swmansion.rnexecutorch.models.BaseModel
86
import com.swmansion.rnexecutorch.utils.ArrayUtils.Companion.createFloatArray
97
import org.pytorch.executorch.EValue
108
import org.pytorch.executorch.Tensor
119

1210
class MoonshineEncoder(
1311
reactApplicationContext: ReactApplicationContext,
14-
) : BaseModel<ReadableArray, WritableArray>(reactApplicationContext) {
15-
override fun runModel(input: ReadableArray): WritableArray = this.postprocess(this.module.forward(this.preprocess(input)))
12+
) : BaseModel<ReadableArray, Array<EValue>>(reactApplicationContext) {
13+
override fun runModel(input: ReadableArray): Array<EValue> = this.module.forward(this.preprocess(input))
1614

17-
fun preprocess(input: ReadableArray): EValue {
15+
private fun preprocess(input: ReadableArray): EValue {
1816
val size = input.size()
1917
val preprocessorInputShape = longArrayOf(1, size.toLong())
2018
return EValue.from(Tensor.fromBlob(createFloatArray(input), preprocessorInputShape))
2119
}
22-
23-
fun postprocess(output: Array<EValue>): WritableArray {
24-
val outputWritableArray: WritableArray = Arguments.createArray()
25-
output[0].toTensor().dataAsFloatArray.map {
26-
outputWritableArray.pushDouble(
27-
it.toDouble(),
28-
)
29-
}
30-
return outputWritableArray
31-
}
3220
}

android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/WhisperDecoder.kt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@ class WhisperDecoder(
1010
reactApplicationContext: ReactApplicationContext,
1111
) : BaseS2TDecoder(reactApplicationContext) {
1212
private lateinit var generatedTokens: IntArray
13-
override var methodName: String
14-
get() = "forward"
15-
set(value) {}
13+
override var methodName: String = "forward"
1614

1715
override fun setGeneratedTokens(tokens: ReadableArray) {
1816
this.generatedTokens = ArrayUtils.createIntArray(tokens)
Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
package com.swmansion.rnexecutorch.models.speechtotext
22

3-
import com.facebook.react.bridge.Arguments
43
import com.facebook.react.bridge.ReactApplicationContext
54
import com.facebook.react.bridge.ReadableArray
6-
import com.facebook.react.bridge.WritableArray
75
import com.swmansion.rnexecutorch.models.BaseModel
86
import com.swmansion.rnexecutorch.utils.ArrayUtils
97
import com.swmansion.rnexecutorch.utils.STFT
@@ -12,35 +10,20 @@ import org.pytorch.executorch.Tensor
1210

1311
class WhisperEncoder(
1412
reactApplicationContext: ReactApplicationContext,
15-
) : BaseModel<ReadableArray, WritableArray>(reactApplicationContext) {
13+
) : BaseModel<ReadableArray, Array<EValue>>(reactApplicationContext) {
1614
private val fftSize = 512
1715
private val hopLength = 160
1816
private val stftFrameSize = (this.fftSize / 2).toLong()
1917
private val stft = STFT(fftSize, hopLength)
2018

21-
override fun runModel(input: ReadableArray): WritableArray {
22-
val inputEValue = this.preprocess(input)
23-
val hiddenState = this.module.forward(inputEValue)
24-
return this.postprocess(hiddenState)
25-
}
19+
override fun runModel(input: ReadableArray): Array<EValue> = this.module.forward(this.preprocess(input))
2620

27-
fun preprocess(input: ReadableArray): EValue {
21+
private fun preprocess(input: ReadableArray): EValue {
2822
val waveformFloatArray = ArrayUtils.createFloatArray(input)
2923

3024
val stftResult = this.stft.fromWaveform(waveformFloatArray)
3125
val numStftFrames = stftResult.size / this.stftFrameSize
3226
val inputTensor = Tensor.fromBlob(stftResult, longArrayOf(numStftFrames, this.stftFrameSize))
3327
return EValue.from(inputTensor)
3428
}
35-
36-
fun postprocess(output: Array<EValue>): WritableArray {
37-
val outputWritableArray: WritableArray = Arguments.createArray()
38-
39-
output[0].toTensor().dataAsFloatArray.map {
40-
outputWritableArray.pushDouble(
41-
it.toDouble(),
42-
)
43-
}
44-
return outputWritableArray
45-
}
4629
}

0 commit comments

Comments
 (0)