Skip to content

Commit ba35183

Browse files
chmjkb and claude committed
chore: migrate TextToImageModule to factory pattern, add TextToImageModelName type
- Add TextToImageModelName union type
- Add modelName to TextToImageProps.model
- TextToImageModule: private constructor, fromModelName, fromCustomModel
- useTextToImage: use factory, add model.modelName to deps

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent b62201a commit ba35183

3 files changed

Lines changed: 174 additions & 74 deletions

File tree

packages/react-native-executorch/src/hooks/computer_vision/useTextToImage.ts

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,36 +20,71 @@ export const useTextToImage = ({
2020
const [isGenerating, setIsGenerating] = useState(false);
2121
const [downloadProgress, setDownloadProgress] = useState(0);
2222
const [error, setError] = useState<RnExecutorchError | null>(null);
23-
24-
const [module] = useState(() => new TextToImageModule(inferenceCallback));
23+
const [moduleInstance, setModuleInstance] =
24+
useState<TextToImageModule | null>(null);
2525

2626
useEffect(() => {
2727
if (preventLoad) return;
2828

29-
(async () => {
30-
setDownloadProgress(0);
31-
setError(null);
32-
try {
33-
setIsReady(false);
34-
await module.load(model, setDownloadProgress);
35-
setIsReady(true);
36-
} catch (err) {
37-
setError(parseUnknownError(err));
29+
let active = true;
30+
setDownloadProgress(0);
31+
setError(null);
32+
setIsReady(false);
33+
34+
TextToImageModule.fromModelName(
35+
{
36+
modelName: model.modelName,
37+
tokenizerSource: model.tokenizerSource,
38+
schedulerSource: model.schedulerSource,
39+
encoderSource: model.encoderSource,
40+
unetSource: model.unetSource,
41+
decoderSource: model.decoderSource,
42+
inferenceCallback,
43+
},
44+
(p) => {
45+
if (active) setDownloadProgress(p);
3846
}
39-
})();
47+
)
48+
.then((mod) => {
49+
if (!active) {
50+
mod.delete();
51+
return;
52+
}
53+
setModuleInstance((prev) => {
54+
prev?.delete();
55+
return mod;
56+
});
57+
setIsReady(true);
58+
})
59+
.catch((err) => {
60+
if (active) setError(parseUnknownError(err));
61+
});
4062

4163
return () => {
42-
module.delete();
64+
active = false;
65+
setModuleInstance((prev) => {
66+
prev?.delete();
67+
return null;
68+
});
4369
};
44-
}, [module, model, preventLoad]);
70+
// eslint-disable-next-line react-hooks/exhaustive-deps
71+
}, [
72+
model.modelName,
73+
model.tokenizerSource,
74+
model.schedulerSource,
75+
model.encoderSource,
76+
model.unetSource,
77+
model.decoderSource,
78+
preventLoad,
79+
]);
4580

4681
const generate = async (
4782
input: string,
4883
imageSize?: number,
4984
numSteps?: number,
5085
seed?: number
5186
): Promise<string> => {
52-
if (!isReady)
87+
if (!isReady || !moduleInstance)
5388
throw new RnExecutorchError(
5489
RnExecutorchErrorCode.ModuleNotLoaded,
5590
'The model is currently not loaded. Please load the model before calling forward().'
@@ -61,17 +96,17 @@ export const useTextToImage = ({
6196
);
6297
try {
6398
setIsGenerating(true);
64-
return await module.forward(input, imageSize, numSteps, seed);
99+
return await moduleInstance.forward(input, imageSize, numSteps, seed);
65100
} finally {
66101
setIsGenerating(false);
67102
}
68103
};
69104

70105
const interrupt = useCallback(() => {
71-
if (isGenerating) {
72-
module.interrupt();
106+
if (isGenerating && moduleInstance) {
107+
moduleInstance.interrupt();
73108
}
74-
}, [module, isGenerating]);
109+
}, [moduleInstance, isGenerating]);
75110

76111
return {
77112
isReady,

packages/react-native-executorch/src/modules/computer_vision/TextToImageModule.ts

Lines changed: 106 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { ResourceFetcher } from '../../utils/ResourceFetcher';
22
import { ResourceSource } from '../../types/common';
3+
import { TextToImageModelName } from '../../types/tti';
34
import { BaseModule } from '../BaseModule';
45

56
import { PNG } from 'pngjs/browser';
@@ -15,82 +16,132 @@ import { Logger } from '../../common/Logger';
1516
export class TextToImageModule extends BaseModule {
1617
private inferenceCallback: (stepIdx: number) => void;
1718

18-
/**
19-
* Creates a new instance of `TextToImageModule` with optional callback on inference step.
20-
*
21-
* @param inferenceCallback - Optional callback function that receives the current step index during inference.
22-
*/
23-
constructor(inferenceCallback?: (stepIdx: number) => void) {
19+
private constructor(inferenceCallback?: (stepIdx: number) => void) {
2420
super();
2521
this.inferenceCallback = (stepIdx: number) => {
2622
inferenceCallback?.(stepIdx);
2723
};
2824
}
2925

3026
/**
31-
* Loads the model from specified resources.
27+
* Creates a Text to Image instance for a built-in model.
3228
*
33-
* @param model - Object containing sources for tokenizer, scheduler, encoder, unet, and decoder.
34-
* @param onDownloadProgressCallback - Optional callback to monitor download progress.
29+
* @param namedSources - An object specifying the model name, pipeline sources, and optional inference callback.
30+
* @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
31+
* @returns A Promise resolving to a `TextToImageModule` instance.
32+
*
33+
* @example
34+
* ```ts
35+
* import { TextToImageModule, BK_SDM_TINY_VPRED_512 } from 'react-native-executorch';
36+
* const tti = await TextToImageModule.fromModelName(BK_SDM_TINY_VPRED_512);
37+
* ```
3538
*/
36-
async load(
37-
model: {
39+
static async fromModelName(
40+
namedSources: {
41+
modelName: TextToImageModelName;
3842
tokenizerSource: ResourceSource;
3943
schedulerSource: ResourceSource;
4044
encoderSource: ResourceSource;
4145
unetSource: ResourceSource;
4246
decoderSource: ResourceSource;
47+
inferenceCallback?: (stepIdx: number) => void;
4348
},
44-
onDownloadProgressCallback: (progress: number) => void = () => {}
45-
): Promise<void> {
49+
onDownloadProgress: (progress: number) => void = () => {}
50+
): Promise<TextToImageModule> {
51+
const instance = new TextToImageModule(namedSources.inferenceCallback);
4652
try {
47-
const results = await ResourceFetcher.fetch(
48-
onDownloadProgressCallback,
49-
model.tokenizerSource,
50-
model.schedulerSource,
51-
model.encoderSource,
52-
model.unetSource,
53-
model.decoderSource
54-
);
55-
if (!results) {
56-
throw new RnExecutorchError(
57-
RnExecutorchErrorCode.DownloadInterrupted,
58-
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
59-
);
60-
}
61-
const [tokenizerPath, schedulerPath, encoderPath, unetPath, decoderPath] =
62-
results;
53+
await instance.internalLoad(namedSources, onDownloadProgress);
54+
return instance;
55+
} catch (error) {
56+
Logger.error('Load failed:', error);
57+
throw parseUnknownError(error);
58+
}
59+
}
6360

64-
if (
65-
!tokenizerPath ||
66-
!schedulerPath ||
67-
!encoderPath ||
68-
!unetPath ||
69-
!decoderPath
70-
) {
71-
throw new RnExecutorchError(
72-
RnExecutorchErrorCode.DownloadInterrupted,
73-
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
74-
);
75-
}
61+
/**
62+
* Creates a Text to Image instance with user-provided model binaries.
63+
* Use this when working with a custom-exported diffusion pipeline.
64+
* Internally uses `'custom'` as the model name for telemetry.
65+
*
66+
* @param sources - An object containing the pipeline source paths.
67+
* @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
68+
* @param inferenceCallback - Optional callback triggered after each diffusion step.
69+
* @returns A Promise resolving to a `TextToImageModule` instance.
70+
*/
71+
static fromCustomModel(
72+
sources: {
73+
tokenizerSource: ResourceSource;
74+
schedulerSource: ResourceSource;
75+
encoderSource: ResourceSource;
76+
unetSource: ResourceSource;
77+
decoderSource: ResourceSource;
78+
},
79+
onDownloadProgress: (progress: number) => void = () => {},
80+
inferenceCallback?: (stepIdx: number) => void
81+
): Promise<TextToImageModule> {
82+
return TextToImageModule.fromModelName(
83+
{
84+
modelName: 'custom' as TextToImageModelName,
85+
...sources,
86+
inferenceCallback,
87+
},
88+
onDownloadProgress
89+
);
90+
}
7691

77-
const response = await fetch('file://' + schedulerPath);
78-
const schedulerConfig = await response.json();
92+
private async internalLoad(
93+
model: {
94+
tokenizerSource: ResourceSource;
95+
schedulerSource: ResourceSource;
96+
encoderSource: ResourceSource;
97+
unetSource: ResourceSource;
98+
decoderSource: ResourceSource;
99+
},
100+
onDownloadProgressCallback: (progress: number) => void
101+
): Promise<void> {
102+
const results = await ResourceFetcher.fetch(
103+
onDownloadProgressCallback,
104+
model.tokenizerSource,
105+
model.schedulerSource,
106+
model.encoderSource,
107+
model.unetSource,
108+
model.decoderSource
109+
);
110+
if (!results) {
111+
throw new RnExecutorchError(
112+
RnExecutorchErrorCode.DownloadInterrupted,
113+
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
114+
);
115+
}
116+
const [tokenizerPath, schedulerPath, encoderPath, unetPath, decoderPath] =
117+
results;
79118

80-
this.nativeModule = global.loadTextToImage(
81-
tokenizerPath,
82-
encoderPath,
83-
unetPath,
84-
decoderPath,
85-
schedulerConfig.beta_start,
86-
schedulerConfig.beta_end,
87-
schedulerConfig.num_train_timesteps,
88-
schedulerConfig.steps_offset
119+
if (
120+
!tokenizerPath ||
121+
!schedulerPath ||
122+
!encoderPath ||
123+
!unetPath ||
124+
!decoderPath
125+
) {
126+
throw new RnExecutorchError(
127+
RnExecutorchErrorCode.DownloadInterrupted,
128+
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
89129
);
90-
} catch (error) {
91-
Logger.error('Load failed:', error);
92-
throw parseUnknownError(error);
93130
}
131+
132+
const response = await fetch('file://' + schedulerPath);
133+
const schedulerConfig = await response.json();
134+
135+
this.nativeModule = global.loadTextToImage(
136+
tokenizerPath,
137+
encoderPath,
138+
unetPath,
139+
decoderPath,
140+
schedulerConfig.beta_start,
141+
schedulerConfig.beta_end,
142+
schedulerConfig.num_train_timesteps,
143+
schedulerConfig.steps_offset
144+
);
94145
}
95146

96147
/**

packages/react-native-executorch/src/types/tti.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
import { RnExecutorchError } from '../errors/errorUtils';
22
import { ResourceSource } from '../types/common';
33

4+
/**
5+
* Union of all built-in Text-to-Image model names.
6+
*
7+
* @category Types
8+
*/
9+
export type TextToImageModelName =
10+
| 'bk-sdm-tiny-vpred-512'
11+
| 'bk-sdm-tiny-vpred-256';
12+
413
/**
514
* Configuration properties for the `useTextToImage` hook.
615
*
@@ -11,6 +20,11 @@ export interface TextToImageProps {
1120
* Object containing the required model sources for the diffusion pipeline.
1221
*/
1322
model: {
23+
/**
24+
* The built-in model name (e.g. `'bk-sdm-tiny-vpred-512'`). Used for telemetry and hook reload triggers.
25+
* Pass one of the pre-built TTI constants (e.g. `BK_SDM_TINY_VPRED_512`) to populate all required fields.
26+
*/
27+
modelName: TextToImageModelName;
1428
/** Source for the text tokenizer binary/config. */
1529
tokenizerSource: ResourceSource;
1630
/** Source for the diffusion scheduler binary/config. */

0 commit comments

Comments (0)