Commit 1578d52

feat: Flag for mean pooling text embeddings (#384)
## Description

Some models do mean pooling, some don't. We can set a flag to take that into account. We should set this flag in all presets for our models while completing #359.

### Type of change

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing documentation)

### Tested on

- [x] iOS
- [x] Android

### Related issues

#353

### Checklist

- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have updated the documentation accordingly
- [ ] My changes generate no new warnings **They actually do, if users don't set the flag**
1 parent df4eae4 commit 1578d52

10 files changed

Lines changed: 61 additions & 21 deletions

docs/docs/natural-language-processing/useTextEmbeddings.md

Lines changed: 4 additions & 0 deletions
@@ -33,6 +33,7 @@ import {
 const model = useTextEmbeddings({
   modelSource: ALL_MINILM_L6_V2,
   tokenizerSource: ALL_MINILM_L6_V2_TOKENIZER,
+  meanPooling: true,
 });
 
 try {
@@ -50,6 +51,8 @@ A string that specifies the location of the model binary. For more information,
 **`tokenizerSource`**
 A string that specifies the location of the tokenizer JSON file.
 
+**`meanPooling?`** - Boolean that controls whether we perform mean pooling on the model output or not. If not set, it will default to true and display warning.
+
 **`preventLoad?`** - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook.
 
 ### Returns
@@ -86,6 +89,7 @@ function App() {
   const model = useTextEmbeddings({
     modelSource: ALL_MINILM_L6_V2,
     tokenizerSource: ALL_MINILM_L6_V2_TOKENIZER,
+    meanPooling: true,
   });
 
   ...

docs/docs/typescript-api/TextEmbeddingsModule.md

Lines changed: 6 additions & 6 deletions
@@ -22,11 +22,11 @@ const embedding = await TextEmbeddingsModule.forward('Hello World!');
 
 ### Methods
 
-| Method | Type | Description |
-| -------------------- | ------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
-| `load` | `(modelSource: ResourceSource, tokenizerSource: ResourceSource): Promise<void>` | Loads the model, where `modelSource` is a string that specifies the location of the model binary and `tokenizerSource` is a string that specifies the location of the tokenizer JSON file. |
-| `forward` | `(input: string): Promise<number[]>` | Executes the model's forward pass, where `input` is a text that will be embedded. |
-| `onDownloadProgress` | `(callback: (downloadProgress: number) => void): any` | Subscribe to the download progress event. |
+| Method | Type | Description |
+| -------------------- | ------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| `load` | `(modelSource: ResourceSource, tokenizerSource: ResourceSource, meanPooling?: boolean): Promise<void>` | Loads the model, where `modelSource` is a string that specifies the location of the model binary, `tokenizerSource` is a string that specifies the location of the tokenizer JSON file, and `meanPooling` controls when to perform pooling on model outputs. |
+| `forward` | `(input: string): Promise<number[]>` | Executes the model's forward pass, where `input` is a text that will be embedded. |
+| `onDownloadProgress` | `(callback: (downloadProgress: number) => void): any` | Subscribe to the download progress event. |
 
 <details>
 <summary>Type definitions</summary>
@@ -39,7 +39,7 @@ type ResourceSource = string | number | object;
 
 ## Loading the model
 
-To load the model, use the `load` method. It accepts the `modelSource` which is a string that specifies the location of the model binary and `tokenizerSource` which is a string that specifies the location of the tokenizer JSON file. For more information, take a look at [loading models](../fundamentals/loading-models.md) page. This method returns a promise, which can resolve to an error or void.
+To load the model, use the `load` method. It accepts the `modelSource` which is a string that specifies the location of the model binary, `tokenizerSource` which is a string that specifies the location of the tokenizer JSON file, and optional `meanPooling` flag controls when to perform pooling on model outputs. For more information, take a look at [loading models](../fundamentals/loading-models.md) page. This method returns a promise, which can resolve to an error or void.
 
 ## Running the model

packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/TextEmbeddings.kt

Lines changed: 2 additions & 1 deletion
@@ -34,10 +34,11 @@ class TextEmbeddings(
 
   override fun forward(
     input: String,
+    meanPooling: Boolean,
     promise: Promise,
   ) {
     try {
-      val output = textEmbeddingsModel.runModel(input)
+      val output = textEmbeddingsModel.runModel(input, meanPooling)
       val writableArray = WritableNativeArray()
       output.forEach { writableArray.pushDouble(it) }

packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/TextEmbeddings/TextEmbeddingsModel.kt

Lines changed: 15 additions & 4 deletions
@@ -24,13 +24,24 @@ class TextEmbeddingsModel(
   fun postprocess(
     modelOutput: FloatArray, // [tokens * embedding_dim]
     attentionMask: LongArray, // [tokens]
+    meanPooling: Boolean,
   ): DoubleArray {
-    val modelOutputDouble = modelOutput.map { it.toDouble() }.toDoubleArray()
-    val embeddings = TextEmbeddingsUtils.meanPooling(modelOutputDouble, attentionMask)
-    return TextEmbeddingsUtils.normalize(embeddings)
+    var embeddings = modelOutput.map { it.toDouble() }.toDoubleArray()
+    if (meanPooling) {
+      embeddings = TextEmbeddingsUtils.meanPooling(embeddings, attentionMask)
+    }
+    embeddings = TextEmbeddingsUtils.normalize(embeddings)
+    return embeddings
   }
 
   override fun runModel(input: String): DoubleArray {
+    return runModel(input, true)
+  }
+
+  fun runModel(
+    input: String,
+    meanPooling: Boolean,
+  ): DoubleArray {
     val modelInput = preprocess(input)
     val inputsIds = modelInput[0]
     val attentionMask = modelInput[1]
@@ -43,6 +54,6 @@ class TextEmbeddingsModel(
 
     val modelOutput = forward(inputIdsEValue, attentionMaskEValue)[0].toTensor().dataAsFloatArray
 
-    return postprocess(modelOutput, attentionMask)
+    return postprocess(modelOutput, attentionMask, meanPooling)
   }
 }
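
The post-processing flow in this file can be sketched in TypeScript as follows. This is a minimal illustration, not the library's code: the diff does not show `TextEmbeddingsUtils.meanPooling` or `TextEmbeddingsUtils.normalize`, so the sketch assumes masked averaging over tokens followed by L2 normalization. The function names `meanPool`, `normalize`, and `postprocess` are hypothetical.

```typescript
// Hypothetical helpers for illustration only; names and layout are assumptions.

// Masked mean pooling: average the token vectors whose attention mask is non-zero.
// `modelOutput` is assumed to be flattened row-major as [tokens * embeddingDim].
function meanPool(modelOutput: number[], attentionMask: number[]): number[] {
  const tokens = attentionMask.length;
  if (tokens === 0) return [];
  const dim = modelOutput.length / tokens;
  const pooled: number[] = new Array<number>(dim).fill(0);
  let kept = 0;
  for (let t = 0; t < tokens; t++) {
    if (attentionMask[t] === 0) continue; // skip padding tokens
    kept++;
    for (let d = 0; d < dim; d++) {
      pooled[d] += modelOutput[t * dim + d];
    }
  }
  // Guard against an all-zero mask so we never divide by zero.
  return pooled.map((v) => v / Math.max(kept, 1));
}

// L2 normalization: scale the vector to unit length (zero vectors are returned unchanged).
function normalize(v: number[]): number[] {
  const norm = Math.sqrt(v.reduce((acc, x) => acc + x * x, 0));
  return norm === 0 ? v.slice() : v.map((x) => x / norm);
}

// Mirrors the shape of the patched postprocess: pooling is optional, normalization always runs.
function postprocess(
  modelOutput: number[],
  attentionMask: number[],
  meanPooling: boolean
): number[] {
  const embeddings = meanPooling
    ? meanPool(modelOutput, attentionMask)
    : modelOutput;
  return normalize(embeddings);
}
```

Under these assumptions, passing `meanPooling = false` leaves the vector length unchanged, so the flag is only appropriate when the model output is already a single pooled embedding.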

packages/react-native-executorch/ios/RnExecutorch/TextEmbeddings.mm

Lines changed: 2 additions & 1 deletion
@@ -41,10 +41,11 @@ - (void)loadModule:(NSString *)modelSource
 }
 
 - (void)forward:(NSString *)input
+    meanPooling:(bool)meanPooling
         resolve:(RCTPromiseResolveBlock)resolve
          reject:(RCTPromiseRejectBlock)reject {
   @try {
-    resolve([model runModel:input]);
+    resolve([model runModel:input meanPooling:meanPooling]);
     return;
   } @catch (NSException *exception) {
     NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason);

packages/react-native-executorch/ios/RnExecutorch/models/text_embeddings/TextEmbeddingsModel.h

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 - (void)loadTokenizer:(NSString *)tokenizerSource;
 - (NSArray *)preprocess:(NSString *)input;
 - (NSArray *)runModel:(NSString *)input;
+- (NSArray *)runModel:(NSString *)input meanPooling:(bool)meanPooling;
 - (NSArray *)postprocess:(NSArray *)input
            attentionMask:(NSArray *)attentionMask;

packages/react-native-executorch/ios/RnExecutorch/models/text_embeddings/TextEmbeddingsModel.mm

Lines changed: 15 additions & 5 deletions
@@ -14,13 +14,21 @@ - (NSArray *)preprocess:(NSString *)input {
 
 - (NSArray *)postprocess:(NSArray *)modelOutput // [tokens * embedding_dim]
            attentionMask:(NSArray *)attentionMask // [tokens]
-{
-  NSArray *embeddings = [TextEmbeddingsUtils meanPooling:modelOutput
-                                           attentionMask:attentionMask];
-  return [TextEmbeddingsUtils normalize:embeddings];
+             meanPooling:(bool)meanPooling {
+  NSArray *embeddings = modelOutput;
+  if (meanPooling) {
+    embeddings = [TextEmbeddingsUtils meanPooling:modelOutput
+                                    attentionMask:attentionMask];
+  }
+  embeddings = [TextEmbeddingsUtils normalize:embeddings];
+  return embeddings;
 }
 
 - (NSArray *)runModel:(NSString *)input {
+  return [self runModel:input meanPooling:true];
+}
+
+- (NSArray *)runModel:(NSString *)input meanPooling:(bool)meanPooling {
   NSArray *modelInput = [self preprocess:input];
 
   NSMutableArray *inputTypes = [NSMutableArray arrayWithObjects:@4, @4, nil];
@@ -34,7 +42,9 @@ - (NSArray *)runModel:(NSString *)input {
   NSArray *modelOutput = [self forward:modelInput
                                 shapes:shapes
                             inputTypes:inputTypes];
-  return [self postprocess:modelOutput[0] attentionMask:modelInput[1]];
+  return [self postprocess:modelOutput[0]
+             attentionMask:modelInput[1]
+               meanPooling:meanPooling];
 }
 
 - (void)loadTokenizer:(NSString *)tokenizerSource {

packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts

Lines changed: 3 additions & 1 deletion
@@ -5,14 +5,16 @@ import { useModule } from '../useModule';
 export const useTextEmbeddings = ({
   modelSource,
   tokenizerSource,
+  meanPooling,
   preventLoad = false,
 }: {
   modelSource: ResourceSource;
   tokenizerSource: ResourceSource;
+  meanPooling?: boolean;
   preventLoad?: boolean;
 }) =>
   useModule({
     module: TextEmbeddingsModule,
-    loadArgs: [modelSource, tokenizerSource],
+    loadArgs: [modelSource, tokenizerSource, meanPooling],
     preventLoad,
   });

packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts

Lines changed: 12 additions & 2 deletions
@@ -4,15 +4,25 @@ import { BaseModule } from '../BaseModule';
 
 export class TextEmbeddingsModule extends BaseModule {
   protected static override nativeModule = TextEmbeddingsNativeModule;
+  private static meanPooling: boolean;
 
   static override async load(
     modelSource: ResourceSource,
-    tokenizerSource: ResourceSource
+    tokenizerSource: ResourceSource,
+    meanPooling?: boolean
   ) {
+    if (meanPooling === undefined) {
+      console.warn(
+        "You haven't passed meanPooling flag. It is defaulting to true. If your model doesn't require pooling it may misbehave."
+      );
+      meanPooling = true;
+    }
+
     await super.load([modelSource, tokenizerSource]);
+    this.meanPooling = meanPooling;
   }
 
   static override async forward(input: string): Promise<number[]> {
-    return this.nativeModule.forward(input);
+    return this.nativeModule.forward(input, this.meanPooling);
   }
 }
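
The default handling added to `load` can be isolated as a small helper, which makes the "warn and default to true" behavior easy to check on its own. This is only a sketch: `resolveMeanPooling` and its injectable `warn` parameter are hypothetical names that do not exist in the library, and the warning text is copied from the diff above.

```typescript
// Hypothetical helper mirroring the undefined check in TextEmbeddingsModule.load.
// `warn` is injectable (an assumption for testability) and defaults to console.warn.
function resolveMeanPooling(
  meanPooling: boolean | undefined,
  warn: (message: string) => void = console.warn
): boolean {
  if (meanPooling === undefined) {
    warn(
      "You haven't passed meanPooling flag. It is defaulting to true. If your model doesn't require pooling it may misbehave."
    );
    return true;
  }
  // An explicit true or false is passed through unchanged and produces no warning.
  return meanPooling;
}
```

Passing `meanPooling` explicitly, as the updated documentation examples do, avoids the warning that the PR checklist mentions.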

packages/react-native-executorch/src/native/NativeTextEmbeddings.ts

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ import { TurboModuleRegistry } from 'react-native';
 
 export interface Spec extends TurboModule {
   loadModule(modelSource: string, tokenizerSource: string): Promise<number>;
-  forward(input: string): Promise<number[]>;
+  forward(input: string, meanPooling: boolean): Promise<number[]>;
 }
 
 export default TurboModuleRegistry.get<Spec>('TextEmbeddings');
