Skip to content

Commit 536e1c4

Browse files
chmjk, msluszniak, mkopcins
authored
fix: make LLMModule and SpeechToTextModule non-static post C++ port (#479)
## Description

<!-- Provide a concise and descriptive summary of the changes implemented in this PR. -->

### Introduces a breaking change?

- [x] Yes
- [ ] No

### Type of change

- [ ] Bug fix (change which fixes an issue)
- [x] New feature (change which adds functionality)
- [x] Documentation update (improves or adds clarity to existing documentation)
- [ ] Other (chores, tests, code style improvements etc.)

### Tested on

- [ ] iOS
- [ ] Android

### Testing instructions

<!-- Provide step-by-step instructions on how to test your changes. Include setup details if necessary. -->

### Screenshots

<!-- Add screenshots here, if applicable -->

### Related issues

<!-- Link related issues here using #issue-number -->

### Checklist

- [ ] I have performed a self-review of my code
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [ ] My changes generate no new warnings

### Additional notes

<!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->

---------

Co-authored-by: Mateusz Sluszniak <56299341+msluszniak@users.noreply.github.com>
Co-authored-by: Mateusz Kopcinski <120639731+mkopcins@users.noreply.github.com>
1 parent ffc1b4e commit 536e1c4

8 files changed

Lines changed: 174 additions & 128 deletions

File tree

docs/docs/03-typescript-api/01-natural-language-processing/LLMModule.md

Lines changed: 35 additions & 26 deletions
Large diffs are not rendered by default.

docs/docs/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md

Lines changed: 45 additions & 15 deletions
Large diffs are not rendered by default.

packages/react-native-executorch/src/controllers/LLMController.ts

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,19 @@ export class LLMController {
3131
private messageHistoryCallback: (messageHistory: Message[]) => void;
3232
private isReadyCallback: (isReady: boolean) => void;
3333
private isGeneratingCallback: (isGenerating: boolean) => void;
34-
private onDownloadProgressCallback:
35-
| ((downloadProgress: number) => void)
36-
| undefined;
3734

3835
constructor({
3936
tokenCallback,
4037
responseCallback,
4138
messageHistoryCallback,
4239
isReadyCallback,
4340
isGeneratingCallback,
44-
onDownloadProgressCallback,
4541
}: {
4642
tokenCallback?: (token: string) => void;
4743
responseCallback?: (response: string) => void;
4844
messageHistoryCallback?: (messageHistory: Message[]) => void;
4945
isReadyCallback?: (isReady: boolean) => void;
5046
isGeneratingCallback?: (isGenerating: boolean) => void;
51-
onDownloadProgressCallback?: (downloadProgress: number) => void;
5247
}) {
5348
if (responseCallback !== undefined) {
5449
Logger.warn(
@@ -74,8 +69,6 @@ export class LLMController {
7469
this._isGenerating = isGenerating;
7570
isGeneratingCallback?.(isGenerating);
7671
};
77-
78-
this.onDownloadProgressCallback = onDownloadProgressCallback;
7972
}
8073

8174
public get response() {
@@ -95,10 +88,12 @@ export class LLMController {
9588
modelSource,
9689
tokenizerSource,
9790
tokenizerConfigSource,
91+
onDownloadProgressCallback,
9892
}: {
9993
modelSource: ResourceSource;
10094
tokenizerSource: ResourceSource;
10195
tokenizerConfigSource: ResourceSource;
96+
onDownloadProgressCallback?: (downloadProgress: number) => void;
10297
}) {
10398
// reset inner state when loading new model
10499
this.responseCallback('');
@@ -108,7 +103,7 @@ export class LLMController {
108103

109104
try {
110105
const paths = await ResourceFetcher.fetch(
111-
this.onDownloadProgressCallback,
106+
onDownloadProgressCallback,
112107
tokenizerSource,
113108
tokenizerConfigSource,
114109
modelSource

packages/react-native-executorch/src/controllers/SpeechToTextController.ts

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,13 @@ export class SpeechToTextController {
3434

3535
// User callbacks
3636
private decodedTranscribeCallback: (sequence: number[]) => void;
37-
private modelDownloadProgressCallback:
38-
| ((downloadProgress: number) => void)
39-
| undefined;
4037
private isReadyCallback: (isReady: boolean) => void;
4138
private isGeneratingCallback: (isGenerating: boolean) => void;
4239
private onErrorCallback: (error: any) => void;
4340
private config!: ModelConfig;
4441

4542
constructor({
4643
transcribeCallback,
47-
modelDownloadProgressCallback,
4844
isReadyCallback,
4945
isGeneratingCallback,
5046
onErrorCallback,
@@ -53,7 +49,6 @@ export class SpeechToTextController {
5349
streamingConfig,
5450
}: {
5551
transcribeCallback: (sequence: string) => void;
56-
modelDownloadProgressCallback?: (downloadProgress: number) => void;
5752
isReadyCallback?: (isReady: boolean) => void;
5853
isGeneratingCallback?: (isGenerating: boolean) => void;
5954
onErrorCallback?: (error: Error | undefined) => void;
@@ -64,7 +59,6 @@ export class SpeechToTextController {
6459
this.tokenizerModule = new TokenizerModule();
6560
this.decodedTranscribeCallback = async (seq) =>
6661
transcribeCallback(await this.tokenIdsToText(seq));
67-
this.modelDownloadProgressCallback = modelDownloadProgressCallback;
6862
this.isReadyCallback = (isReady) => {
6963
this.isReady = isReady;
7064
isReadyCallback?.(isReady);
@@ -88,12 +82,19 @@ export class SpeechToTextController {
8882
);
8983
}
9084

91-
public async loadModel(
92-
modelName: AvailableModels,
93-
encoderSource?: ResourceSource,
94-
decoderSource?: ResourceSource,
95-
tokenizerSource?: ResourceSource
96-
) {
85+
public async load({
86+
modelName,
87+
encoderSource,
88+
decoderSource,
89+
tokenizerSource,
90+
onDownloadProgressCallback,
91+
}: {
92+
modelName: AvailableModels;
93+
encoderSource?: ResourceSource;
94+
decoderSource?: ResourceSource;
95+
tokenizerSource?: ResourceSource;
96+
onDownloadProgressCallback?: (downloadProgress: number) => void;
97+
}) {
9798
this.onErrorCallback(undefined);
9899
this.isReadyCallback(false);
99100
this.config = MODEL_CONFIGS[modelName];
@@ -103,7 +104,7 @@ export class SpeechToTextController {
103104
tokenizerSource || this.config.tokenizer.source
104105
);
105106
const paths = await ResourceFetcher.fetch(
106-
this.modelDownloadProgressCallback,
107+
onDownloadProgressCallback,
107108
encoderSource || this.config.sources.encoder,
108109
decoderSource || this.config.sources.decoder
109110
);

packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ export const useLLM = ({
4343
messageHistoryCallback: setMessageHistory,
4444
isReadyCallback: setIsReady,
4545
isGeneratingCallback: setIsGenerating,
46-
onDownloadProgressCallback: setDownloadProgress,
4746
}),
4847
[tokenCallback]
4948
);
@@ -60,6 +59,7 @@ export const useLLM = ({
6059
modelSource,
6160
tokenizerSource,
6261
tokenizerConfigSource,
62+
onDownloadProgressCallback: setDownloadProgress,
6363
});
6464
} catch (e) {
6565
setError(e);

packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ export const useSpeechToText = ({
6060
isReadyCallback: setIsReady,
6161
isGeneratingCallback: setIsGenerating,
6262
onErrorCallback: setError,
63-
modelDownloadProgressCallback: setDownloadProgress,
6463
}),
6564
[]
6665
);
@@ -71,12 +70,13 @@ export const useSpeechToText = ({
7170

7271
useEffect(() => {
7372
const loadModel = async () => {
74-
await model.loadModel(
73+
await model.load({
7574
modelName,
7675
encoderSource,
7776
decoderSource,
78-
tokenizerSource
79-
);
77+
tokenizerSource,
78+
onDownloadProgressCallback: setDownloadProgress,
79+
});
8080
};
8181
if (!preventLoad) {
8282
loadModel();

packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,47 +3,52 @@ import { ResourceSource } from '../../types/common';
33
import { ChatConfig, LLMTool, Message, ToolsConfig } from '../../types/llm';
44

55
export class LLMModule {
6-
static controller: LLMController;
6+
private controller: LLMController;
77

8-
static async load({
8+
constructor({
9+
tokenCallback,
10+
responseCallback,
11+
messageHistoryCallback,
12+
}: {
13+
tokenCallback?: (token: string) => void;
14+
responseCallback?: (response: string) => void;
15+
messageHistoryCallback?: (messageHistory: Message[]) => void;
16+
} = {}) {
17+
this.controller = new LLMController({
18+
tokenCallback,
19+
responseCallback,
20+
messageHistoryCallback,
21+
});
22+
}
23+
24+
async load({
925
modelSource,
1026
tokenizerSource,
1127
tokenizerConfigSource,
1228
onDownloadProgressCallback,
13-
tokenCallback,
14-
responseCallback,
15-
messageHistoryCallback,
1629
}: {
1730
modelSource: ResourceSource;
1831
tokenizerSource: ResourceSource;
1932
tokenizerConfigSource: ResourceSource;
2033
onDownloadProgressCallback?: (_downloadProgress: number) => void;
21-
tokenCallback?: (token: string) => void;
22-
responseCallback?: (response: string) => void;
23-
messageHistoryCallback?: (messageHistory: Message[]) => void;
2434
}) {
25-
this.controller = new LLMController({
26-
tokenCallback: tokenCallback,
27-
responseCallback: responseCallback,
28-
messageHistoryCallback: messageHistoryCallback,
29-
onDownloadProgressCallback: onDownloadProgressCallback,
30-
});
3135
await this.controller.load({
3236
modelSource,
3337
tokenizerSource,
3438
tokenizerConfigSource,
39+
onDownloadProgressCallback,
3540
});
3641
}
3742

38-
static setTokenCallback({
43+
setTokenCallback({
3944
tokenCallback,
4045
}: {
4146
tokenCallback: (token: string) => void;
4247
}) {
4348
this.controller.setTokenCallback(tokenCallback);
4449
}
4550

46-
static configure({
51+
configure({
4752
chatConfig,
4853
toolsConfig,
4954
}: {
@@ -53,34 +58,31 @@ export class LLMModule {
5358
this.controller.configure({ chatConfig, toolsConfig });
5459
}
5560

56-
static async forward(input: string): Promise<string> {
61+
async forward(input: string): Promise<string> {
5762
await this.controller.forward(input);
5863
return this.controller.response;
5964
}
6065

61-
static async generate(
62-
messages: Message[],
63-
tools?: LLMTool[]
64-
): Promise<string> {
66+
async generate(messages: Message[], tools?: LLMTool[]): Promise<string> {
6567
await this.controller.generate(messages, tools);
6668
return this.controller.response;
6769
}
6870

69-
static async sendMessage(message: string): Promise<Message[]> {
71+
async sendMessage(message: string): Promise<Message[]> {
7072
await this.controller.sendMessage(message);
7173
return this.controller.messageHistory;
7274
}
7375

74-
static async deleteMessage(index: number): Promise<Message[]> {
75-
await this.controller.deleteMessage(index);
76+
deleteMessage(index: number): Message[] {
77+
this.controller.deleteMessage(index);
7678
return this.controller.messageHistory;
7779
}
7880

79-
static interrupt() {
81+
interrupt() {
8082
this.controller.interrupt();
8183
}
8284

83-
static delete() {
85+
delete() {
8486
this.controller.delete();
8587
}
8688
}

packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,70 +4,79 @@ import { AvailableModels, SpeechToTextLanguage } from '../../types/stt';
44
import { STREAMING_ACTION } from '../../constants/sttDefaults';
55

66
export class SpeechToTextModule {
7-
static module: SpeechToTextController;
7+
private module: SpeechToTextController;
88

9-
static onDownloadProgressCallback = (_downloadProgress: number) => {};
10-
11-
static async load(
12-
modelName: AvailableModels,
13-
transcribeCallback: (sequence: string) => void,
14-
modelDownloadProgressCallback?: (downloadProgress: number) => void,
15-
encoderSource?: ResourceSource,
16-
decoderSource?: ResourceSource,
17-
tokenizerSource?: ResourceSource,
9+
constructor({
10+
transcribeCallback,
11+
overlapSeconds,
12+
windowSize,
13+
streamingConfig,
14+
}: {
15+
transcribeCallback?: (sequence: string) => void;
1816
overlapSeconds?: ConstructorParameters<
1917
typeof SpeechToTextController
20-
>['0']['overlapSeconds'],
18+
>['0']['overlapSeconds'];
2119
windowSize?: ConstructorParameters<
2220
typeof SpeechToTextController
23-
>['0']['windowSize'],
21+
>['0']['windowSize'];
2422
streamingConfig?: ConstructorParameters<
2523
typeof SpeechToTextController
26-
>['0']['streamingConfig']
27-
) {
24+
>['0']['streamingConfig'];
25+
} = {}) {
2826
this.module = new SpeechToTextController({
29-
transcribeCallback: transcribeCallback,
30-
modelDownloadProgressCallback: modelDownloadProgressCallback,
31-
overlapSeconds: overlapSeconds,
32-
windowSize: windowSize,
33-
streamingConfig: streamingConfig,
27+
transcribeCallback: transcribeCallback || (() => {}),
28+
overlapSeconds,
29+
windowSize,
30+
streamingConfig,
3431
});
35-
await this.module.loadModel(
36-
(modelName = modelName),
37-
(encoderSource = encoderSource),
38-
(decoderSource = decoderSource),
39-
(tokenizerSource = tokenizerSource)
40-
);
4132
}
4233

43-
static configureStreaming(
34+
async load({
35+
modelName,
36+
encoderSource,
37+
decoderSource,
38+
tokenizerSource,
39+
onDownloadProgressCallback,
40+
}: {
41+
modelName: AvailableModels;
42+
encoderSource?: ResourceSource;
43+
decoderSource?: ResourceSource;
44+
tokenizerSource?: ResourceSource;
45+
onDownloadProgressCallback?: (downloadProgress: number) => void;
46+
}) {
47+
await this.module.load({
48+
modelName,
49+
encoderSource,
50+
decoderSource,
51+
tokenizerSource,
52+
onDownloadProgressCallback,
53+
});
54+
}
55+
56+
configureStreaming(
4457
overlapSeconds: Parameters<SpeechToTextController['configureStreaming']>[0],
4558
windowSize: Parameters<SpeechToTextController['configureStreaming']>[1],
4659
streamingConfig: Parameters<SpeechToTextController['configureStreaming']>[2]
4760
) {
48-
this.module?.configureStreaming(
49-
overlapSeconds,
50-
windowSize,
51-
streamingConfig
52-
);
61+
this.module.configureStreaming(overlapSeconds, windowSize, streamingConfig);
5362
}
5463

55-
static async encode(waveform: Float32Array) {
64+
async encode(waveform: Float32Array) {
5665
return await this.module.encode(waveform);
5766
}
5867

59-
static async decode(seq: number[]) {
68+
async decode(seq: number[]) {
6069
return await this.module.decode(seq);
6170
}
6271

63-
static async transcribe(
72+
async transcribe(
6473
waveform: number[],
6574
audioLanguage?: SpeechToTextLanguage
6675
): ReturnType<SpeechToTextController['transcribe']> {
6776
return await this.module.transcribe(waveform, audioLanguage);
6877
}
6978

70-
static async streamingTranscribe(
79+
async streamingTranscribe(
7180
streamAction: STREAMING_ACTION,
7281
waveform?: number[],
7382
audioLanguage?: SpeechToTextLanguage

0 commit comments

Comments (0)