11import { ResourceFetcher } from '../../utils/ResourceFetcher' ;
22import { ResourceSource } from '../../types/common' ;
3+ import { TextToImageModelName } from '../../types/tti' ;
34import { BaseModule } from '../BaseModule' ;
45
56import { PNG } from 'pngjs/browser' ;
@@ -15,82 +16,132 @@ import { Logger } from '../../common/Logger';
1516export class TextToImageModule extends BaseModule {
1617 private inferenceCallback : ( stepIdx : number ) => void ;
1718
18- /**
19- * Creates a new instance of `TextToImageModule` with optional callback on inference step.
20- *
21- * @param inferenceCallback - Optional callback function that receives the current step index during inference.
22- */
23- constructor ( inferenceCallback ?: ( stepIdx : number ) => void ) {
19+ private constructor ( inferenceCallback ?: ( stepIdx : number ) => void ) {
2420 super ( ) ;
2521 this . inferenceCallback = ( stepIdx : number ) => {
2622 inferenceCallback ?.( stepIdx ) ;
2723 } ;
2824 }
2925
3026 /**
31- * Loads the model from specified resources .
27+ * Creates a Text to Image instance for a built-in model .
3228 *
33- * @param model - Object containing sources for tokenizer, scheduler, encoder, unet, and decoder.
34- * @param onDownloadProgressCallback - Optional callback to monitor download progress.
29+ * @param namedSources - An object specifying the model name, pipeline sources, and optional inference callback.
30+ * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
31+ * @returns A Promise resolving to a `TextToImageModule` instance.
32+ *
33+ * @example
34+ * ```ts
35+ * import { TextToImageModule, BK_SDM_TINY_VPRED_512 } from 'react-native-executorch';
36+ * const tti = await TextToImageModule.fromModelName(BK_SDM_TINY_VPRED_512);
37+ * ```
3538 */
36- async load (
37- model : {
39+ static async fromModelName (
40+ namedSources : {
41+ modelName : TextToImageModelName ;
3842 tokenizerSource : ResourceSource ;
3943 schedulerSource : ResourceSource ;
4044 encoderSource : ResourceSource ;
4145 unetSource : ResourceSource ;
4246 decoderSource : ResourceSource ;
47+ inferenceCallback ?: ( stepIdx : number ) => void ;
4348 } ,
44- onDownloadProgressCallback : ( progress : number ) => void = ( ) => { }
45- ) : Promise < void > {
49+ onDownloadProgress : ( progress : number ) => void = ( ) => { }
50+ ) : Promise < TextToImageModule > {
51+ const instance = new TextToImageModule ( namedSources . inferenceCallback ) ;
4652 try {
47- const results = await ResourceFetcher . fetch (
48- onDownloadProgressCallback ,
49- model . tokenizerSource ,
50- model . schedulerSource ,
51- model . encoderSource ,
52- model . unetSource ,
53- model . decoderSource
54- ) ;
55- if ( ! results ) {
56- throw new RnExecutorchError (
57- RnExecutorchErrorCode . DownloadInterrupted ,
58- 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
59- ) ;
60- }
61- const [ tokenizerPath , schedulerPath , encoderPath , unetPath , decoderPath ] =
62- results ;
53+ await instance . internalLoad ( namedSources , onDownloadProgress ) ;
54+ return instance ;
55+ } catch ( error ) {
56+ Logger . error ( 'Load failed:' , error ) ;
57+ throw parseUnknownError ( error ) ;
58+ }
59+ }
6360
64- if (
65- ! tokenizerPath ||
66- ! schedulerPath ||
67- ! encoderPath ||
68- ! unetPath ||
69- ! decoderPath
70- ) {
71- throw new RnExecutorchError (
72- RnExecutorchErrorCode . DownloadInterrupted ,
73- 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
74- ) ;
75- }
61+ /**
62+ * Creates a Text to Image instance with user-provided model binaries.
63+ * Use this when working with a custom-exported diffusion pipeline.
64+ * Internally uses `'custom'` as the model name for telemetry.
65+ *
66+ * @param sources - An object containing the pipeline source paths.
67+ * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
68+ * @param inferenceCallback - Optional callback triggered after each diffusion step.
69+ * @returns A Promise resolving to a `TextToImageModule` instance.
70+ */
71+ static fromCustomModel (
72+ sources : {
73+ tokenizerSource : ResourceSource ;
74+ schedulerSource : ResourceSource ;
75+ encoderSource : ResourceSource ;
76+ unetSource : ResourceSource ;
77+ decoderSource : ResourceSource ;
78+ } ,
79+ onDownloadProgress : ( progress : number ) => void = ( ) => { } ,
80+ inferenceCallback ?: ( stepIdx : number ) => void
81+ ) : Promise < TextToImageModule > {
82+ return TextToImageModule . fromModelName (
83+ {
84+ modelName : 'custom' as TextToImageModelName ,
85+ ...sources ,
86+ inferenceCallback,
87+ } ,
88+ onDownloadProgress
89+ ) ;
90+ }
7691
77- const response = await fetch ( 'file://' + schedulerPath ) ;
78- const schedulerConfig = await response . json ( ) ;
92+ private async internalLoad (
93+ model : {
94+ tokenizerSource : ResourceSource ;
95+ schedulerSource : ResourceSource ;
96+ encoderSource : ResourceSource ;
97+ unetSource : ResourceSource ;
98+ decoderSource : ResourceSource ;
99+ } ,
100+ onDownloadProgressCallback : ( progress : number ) => void
101+ ) : Promise < void > {
102+ const results = await ResourceFetcher . fetch (
103+ onDownloadProgressCallback ,
104+ model . tokenizerSource ,
105+ model . schedulerSource ,
106+ model . encoderSource ,
107+ model . unetSource ,
108+ model . decoderSource
109+ ) ;
110+ if ( ! results ) {
111+ throw new RnExecutorchError (
112+ RnExecutorchErrorCode . DownloadInterrupted ,
113+ 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
114+ ) ;
115+ }
116+ const [ tokenizerPath , schedulerPath , encoderPath , unetPath , decoderPath ] =
117+ results ;
79118
80- this . nativeModule = global . loadTextToImage (
81- tokenizerPath ,
82- encoderPath ,
83- unetPath ,
84- decoderPath ,
85- schedulerConfig . beta_start ,
86- schedulerConfig . beta_end ,
87- schedulerConfig . num_train_timesteps ,
88- schedulerConfig . steps_offset
119+ if (
120+ ! tokenizerPath ||
121+ ! schedulerPath ||
122+ ! encoderPath ||
123+ ! unetPath ||
124+ ! decoderPath
125+ ) {
126+ throw new RnExecutorchError (
127+ RnExecutorchErrorCode . DownloadInterrupted ,
128+ 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
89129 ) ;
90- } catch ( error ) {
91- Logger . error ( 'Load failed:' , error ) ;
92- throw parseUnknownError ( error ) ;
93130 }
131+
132+ const response = await fetch ( 'file://' + schedulerPath ) ;
133+ const schedulerConfig = await response . json ( ) ;
134+
135+ this . nativeModule = global . loadTextToImage (
136+ tokenizerPath ,
137+ encoderPath ,
138+ unetPath ,
139+ decoderPath ,
140+ schedulerConfig . beta_start ,
141+ schedulerConfig . beta_end ,
142+ schedulerConfig . num_train_timesteps ,
143+ schedulerConfig . steps_offset
144+ ) ;
94145 }
95146
96147 /**
0 commit comments