diff --git a/.cspell-wordlist.txt b/.cspell-wordlist.txt index 19e87ed956..f07211de51 100644 --- a/.cspell-wordlist.txt +++ b/.cspell-wordlist.txt @@ -66,3 +66,11 @@ softmax logit logits probs +unet +Unet +VPRED +timesteps +Timesteps +denoises +denoise +denoising \ No newline at end of file diff --git a/apps/computer-vision/app/_layout.tsx b/apps/computer-vision/app/_layout.tsx index 35a0cc0da4..5914d2fe8a 100644 --- a/apps/computer-vision/app/_layout.tsx +++ b/apps/computer-vision/app/_layout.tsx @@ -100,6 +100,14 @@ export default function _layout() { headerTitleStyle: { color: ColorPalette.primary }, }} /> + Style Transfer + router.navigate('text_to_image/')} + > + Image Generation + ); diff --git a/apps/computer-vision/app/text_to_image/index.tsx b/apps/computer-vision/app/text_to_image/index.tsx new file mode 100644 index 0000000000..14d394af1e --- /dev/null +++ b/apps/computer-vision/app/text_to_image/index.tsx @@ -0,0 +1,183 @@ +import { + View, + StyleSheet, + Text, + Image, + Keyboard, + TouchableWithoutFeedback, +} from 'react-native'; +import React, { useContext, useEffect, useState } from 'react'; +import Spinner from 'react-native-loading-spinner-overlay'; +import { useTextToImage, BK_SDM_TINY_VPRED_256 } from 'react-native-executorch'; +import { GeneratingContext } from '../../context'; +import ColorPalette from '../../colors'; +import ProgressBar from '../../components/ProgressBar'; +import { BottomBarWithTextInput } from '../../components/BottomBarWithTextInput'; + +export default function TextToImageScreen() { + const [inferenceStepIdx, setInferenceStepIdx] = useState(0); + const [imageTitle, setImageTitle] = useState(null); + const [image, setImage] = useState(null); + const [steps, setSteps] = useState(10); + const [showTextInput, setShowTextInput] = useState(false); + const [keyboardVisible, setKeyboardVisible] = useState(false); + + const imageSize = 224; + const model = useTextToImage({ + model: BK_SDM_TINY_VPRED_256, + inferenceCallback: (x) 
=> setInferenceStepIdx(x), + }); + + const { setGlobalGenerating } = useContext(GeneratingContext); + + useEffect(() => { + setGlobalGenerating(model.isGenerating); + }, [model.isGenerating, setGlobalGenerating]); + + useEffect(() => { + const showSub = Keyboard.addListener('keyboardDidShow', () => { + setKeyboardVisible(true); + }); + const hideSub = Keyboard.addListener('keyboardDidHide', () => { + setKeyboardVisible(false); + }); + return () => { + showSub.remove(); + hideSub.remove(); + }; + }, []); + + /* Generates an image for `input` using `numSteps` inference steps; blank prompts are ignored. On an empty result the previous title is restored, on error the title is cleared, and the step index is always reset afterwards. */ const runForward = async (input: string, numSteps: number) => { + if (!input || !input.trim()) return; + const prevImageTitle = imageTitle; + setImageTitle(input); + setSteps(numSteps); + try { + /* Pass `numSteps` directly: the `steps` state captured by this render still holds the previous value until the next render. */ const output = await model.generate(input, imageSize, numSteps); + if (!output.length) { + setImageTitle(prevImageTitle); + return; + } + setImage(output); + } catch (e) { + console.error(e); + setImageTitle(null); + } finally { + setInferenceStepIdx(0); + } + }; + + if (!model.isReady) { + // TODO: Update when #614 merged + return ( + + ); + } + + return ( + { + Keyboard.dismiss(); + setShowTextInput(false); + }} + > + + {keyboardVisible && } + + + {imageTitle && {imageTitle}} + + + {model.isGenerating ? ( + + Generating... + + + ) : ( + + {image?.length ? 
( + + ) : ( + + )} + + )} + + + + + + + ); +} + +const styles = StyleSheet.create({ + container: { + flex: 1, + width: '100%', + alignItems: 'center', + }, + overlay: { + ...StyleSheet.absoluteFillObject, + backgroundColor: 'rgba(0,0,0,0.65)', + zIndex: 5, + }, + titleContainer: { + alignItems: 'center', + marginTop: 20, + }, + titleText: { + color: ColorPalette.primary, + fontSize: 20, + fontWeight: 'bold', + marginBottom: 12, + textAlign: 'center', + }, + text: { + fontSize: 16, + color: '#000', + }, + imageContainer: { + flex: 1, + position: 'absolute', + top: 100, + alignItems: 'center', + justifyContent: 'center', + }, + image: { + width: 256, + height: 256, + marginVertical: 30, + resizeMode: 'contain', + }, + progressContainer: { + flex: 1, + justifyContent: 'center', + alignItems: 'center', + }, + bottomContainer: { + flex: 1, + width: '90%', + position: 'absolute', + bottom: 0, + marginBottom: 25, + zIndex: 10, + }, +}); diff --git a/apps/computer-vision/components/BottomBarWithTextInput.tsx b/apps/computer-vision/components/BottomBarWithTextInput.tsx new file mode 100644 index 0000000000..2ebf78a56c --- /dev/null +++ b/apps/computer-vision/components/BottomBarWithTextInput.tsx @@ -0,0 +1,165 @@ +import React, { useState } from 'react'; +import { + View, + Text, + TextInput, + TouchableOpacity, + StyleSheet, + KeyboardAvoidingView, + Platform, +} from 'react-native'; +import { Ionicons } from '@expo/vector-icons'; +import ColorPalette from '../colors'; + +interface BottomBarProps { + runModel: (input: string, numSteps: number) => void; + stopModel: () => void; + isGenerating?: boolean; + isReady?: boolean; + showTextInput: boolean; + setShowTextInput: React.Dispatch>; + keyboardVisible: boolean; +} + +export const BottomBarWithTextInput = ({ + runModel, + stopModel, + isGenerating, + isReady, + showTextInput, + setShowTextInput, + keyboardVisible, +}: BottomBarProps) => { + const [input, setInput] = useState(''); + const [numSteps, setNumSteps] = 
useState(10); + + const decreaseSteps = () => setNumSteps((prev) => Math.max(5, prev - 5)); + const increaseSteps = () => setNumSteps((prev) => Math.min(50, prev + 5)); + + if (!showTextInput) { + if (isGenerating) { + return ( + + Stop model + + ); + } else { + return ( + setShowTextInput(true)} + disabled={!isReady} + > + Run model + + ); + } + } + + return ( + + + + { + setShowTextInput(false); + setInput(''); + runModel(input, numSteps); + }} + disabled={!isReady || isGenerating} + > + + + + + + + Steps: {numSteps} + + + + - + + + + + + + + + ); +}; + +const styles = StyleSheet.create({ + container: { + alignItems: 'center', + }, + inputContainer: { + flexDirection: 'row', + alignItems: 'center', + justifyContent: 'center', + }, + input: { + flex: 1, + borderRadius: 6, + padding: 8, + marginRight: 8, + backgroundColor: '#fff', + color: '#000', + }, + stepsContainer: { + width: '100%', + flexDirection: 'row', + alignItems: 'center', + justifyContent: 'space-between', + marginTop: 10, + }, + stepsButtons: { + flexDirection: 'row', + }, + button: { + width: '100%', + height: 40, + justifyContent: 'center', + alignItems: 'center', + backgroundColor: ColorPalette.primary, + borderRadius: 8, + }, + buttonText: { + color: '#fff', + fontSize: 16, + textAlign: 'center', + }, + iconButton: { + marginHorizontal: 5, + width: 40, + }, + text: { + flex: 1, + fontSize: 16, + color: '#000', + }, + textWhite: { + color: '#fff', + }, +}); diff --git a/apps/computer-vision/components/ProgressBar.tsx b/apps/computer-vision/components/ProgressBar.tsx new file mode 100644 index 0000000000..fe2850de91 --- /dev/null +++ b/apps/computer-vision/components/ProgressBar.tsx @@ -0,0 +1,59 @@ +import React from 'react'; +import { View, StyleSheet } from 'react-native'; +import ColorPalette from '../colors'; + +type ProgressBarProps = { + numSteps: number; + currentStep: number; +}; + +export default function ProgressBar({ + numSteps, + currentStep, +}: ProgressBarProps) { + return ( + + 
{Array.from({ length: numSteps }).map((_, i) => ( + + ))} + + ); +} + +const styles = StyleSheet.create({ + progressBarContainer: { + flexDirection: 'row', + alignItems: 'center', + justifyContent: 'center', + marginVertical: 16, + width: '80%', + }, + progressStep: { + flex: 1, + height: 15, + }, + progressStepActive: { + backgroundColor: ColorPalette.primary, + }, + progressStepInactive: { + backgroundColor: '#e0e0e0', + }, + progressStepFirst: { + borderTopLeftRadius: 8, + borderBottomLeftRadius: 8, + }, + progressStepLast: { + borderTopRightRadius: 8, + borderBottomRightRadius: 8, + }, +}); diff --git a/apps/computer-vision/package.json b/apps/computer-vision/package.json index ba0094beee..aad49f93c9 100644 --- a/apps/computer-vision/package.json +++ b/apps/computer-vision/package.json @@ -40,6 +40,7 @@ }, "devDependencies": { "@babel/core": "^7.25.2", + "@types/pngjs": "^6.0.5", "@types/react": "~19.0.10" }, "private": true diff --git a/docs/docs/02-hooks/02-computer-vision/useImageEmbeddings.md b/docs/docs/02-hooks/02-computer-vision/useImageEmbeddings.md index 1849a95ceb..6dbdc7dcca 100644 --- a/docs/docs/02-hooks/02-computer-vision/useImageEmbeddings.md +++ b/docs/docs/02-hooks/02-computer-vision/useImageEmbeddings.md @@ -91,9 +91,9 @@ try { ## Supported models -| Model | Language | Image size | Embedding Dimensions | Description | +| Model | Language | Image size | Embedding dimensions | Description | | ---------------------------------------------------------------------------------- | :------: | :--------: | :------------------: | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| 
[clip-vit-base-patch32-image](https://huggingface.co/openai/clip-vit-base-patch32) | English | 224 x 224 | 512 | CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. CLIP allows to embed images and text into the same vector space. This allows to find similar images as well as to implement image search. This is the image encoder part of the CLIP model. To embed text checkout [clip-vit-base-patch32-text](../01-natural-language-processing/useTextEmbeddings.md#supported-models). | +| [clip-vit-base-patch32-image](https://huggingface.co/openai/clip-vit-base-patch32) | English | 224×224 | 512 | CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. CLIP allows to embed images and text into the same vector space. This allows to find similar images as well as to implement image search. This is the image encoder part of the CLIP model. To embed text checkout [clip-vit-base-patch32-text](../01-natural-language-processing/useTextEmbeddings.md#supported-models). | **`Image size`** - the size of an image that the model takes as an input. Resize will happen automatically. diff --git a/docs/docs/02-hooks/02-computer-vision/useTextToImage.md b/docs/docs/02-hooks/02-computer-vision/useTextToImage.md new file mode 100644 index 0000000000..83e47a3e2d --- /dev/null +++ b/docs/docs/02-hooks/02-computer-vision/useTextToImage.md @@ -0,0 +1,133 @@ +--- +title: useTextToImage +keywords: [image generation] +description: "Learn how to use image generation models in your React Native applications with React Native ExecuTorch's useTextToImage hook." +--- + +Text-to-image is a process of generating images directly from a description in natural language by conditioning a model on the provided text input. Our implementation follows the Stable Diffusion pipeline, which applies the diffusion process in a lower-dimensional latent space to reduce memory requirements. 
The pipeline combines a text encoder to preprocess the prompt, a U-Net that iteratively denoises latent representations, and a VAE decoder to reconstruct the final image. React Native ExecuTorch offers a dedicated hook, `useTextToImage`, for this task. + + + +:::warning +It is recommended to use models provided by us, which are available at our Hugging Face repository. You can also use [constants](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/constants/modelUrls.ts) shipped with our library. +::: + +## Reference + +```typescript +import { useTextToImage, BK_SDM_TINY_VPRED_256 } from 'react-native-executorch'; + +const model = useTextToImage({ model: BK_SDM_TINY_VPRED_256 }); + +const input = 'a castle'; + +try { + const image = await model.generate(input); +} catch (error) { + console.error(error); +} +``` + +### Arguments + +**`model`** - Object containing the model source. + +- **`schedulerSource`** - A string that specifies the location of the scheduler config. + +- **`tokenizerSource`** - A string that specifies the location of the tokenizer config. + +- **`encoderSource`** - A string that specifies the location of the text encoder binary. + +- **`unetSource`** - A string that specifies the location of the U-Net binary. + +- **`decoderSource`** - A string that specifies the location of the VAE decoder binary. + +**`preventLoad?`** - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. + +For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page. 
+ +### Returns + +| Field | Type | Description | +| ------------------ | ------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `generate` | `(input: string, imageSize?: number, numSteps?: number, seed?: number) => Promise<string>` | Runs the model to generate an image described by `input`, and conditioned by `seed`, performing `numSteps` inference steps. The resulting image, with dimensions `imageSize`×`imageSize` pixels, is returned as a base64-encoded string. | +| `error` | `string \| null` | Contains the error message if the model failed to load. | +| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. | +| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. | +| `downloadProgress` | `number` | Represents the download progress as a value between 0 and 1. | +| `interrupt()` | `() => void` | Interrupts the current inference. The model is stopped in the nearest inference step. | + +## Running the model + +To run the model, you can use the `generate` method. It accepts four arguments: a text prompt describing the requested image, a size of the image in pixels, a number of denoising steps, and an optional seed value, which enables reproducibility of the results. + +The image size must be a multiple of 32 due to the architecture of the U-Net and VAE models. The seed should be a positive integer. + +:::warning +Larger imageSize values require significantly more memory to run the model. +::: + +## Example + +```tsx +import { useTextToImage, BK_SDM_TINY_VPRED_256 } from 'react-native-executorch'; + +function App() { + const model = useTextToImage({ model: BK_SDM_TINY_VPRED_256 }); + + //... 
+ const input = 'a medieval castle by the sea shore'; + + const imageSize = 256; + const numSteps = 25; + + try { + image = await model.generate(input, imageSize, numSteps); + } catch (error) { + console.error(error); + } + //... + + return ; +} +``` + +| ![Castle 256x256](../../../static/img/castle256.png) | ![Castle 512x512](../../../static/img/castle512.png) | +| ---------------------------------------------------- | ---------------------------------------------------- | +| Image of size 256×256 | Image of size 512×512 | + +## Supported models + +| Model | Parameters [B] | Description | +| ------------------------------------------------------------------- | -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| [bk-sdm-tiny-vpred](https://huggingface.co/vivym/bk-sdm-tiny-vpred) | 0.5 | BK-SDM (Block-removed Knowledge-distilled Stable Diffusion Model) is a compressed version of Stable Diffusion v1.4 with several residual and attention blocks removed. The BK-SDM-Tiny is a v-prediction variant of the model, obtained through further block removal, built around a 0.33B-parameter U-Net. | + +## Benchmarks + +:::info +The number following the underscore (\_) indicates that the model supports generating image with dimensions ranging from 128 pixels up to that value. This setting doesn’t affect the model’s file size - it only determines how memory is allocated at runtime, based on the maximum allowed image size. 
+::: + +### Model size + +| Model | Text encoder (XNNPACK) [MB] | UNet (XNNPACK) [MB] | VAE decoder (XNNPACK) [MB] | +| --------------------- | --------------------------- | ------------------- | -------------------------- | +| BK_SDM_TINY_VPRED_256 | 492 | 1290 | 198 | +| BK_SDM_TINY_VPRED_512 | 492 | 1290 | 198 | + +### Memory usage + +| Model | Android (XNNPACK) [MB] | iOS (XNNPACK) [MB] | +| --------------------- | ---------------------- | ------------------ | +| BK_SDM_TINY_VPRED_256 | 2900 | 2800 | +| BK_SDM_TINY_VPRED_512 | 6700 | 6560 | + +### Inference time + +| Model | iPhone 16 Pro (XNNPACK) [ms] | iPhone 14 Pro Max (XNNPACK) [ms] | iPhone SE 3 (XNNPACK) | Samsung Galaxy S24 (XNNPACK) [ms] | OnePlus 12 (XNNPACK) [ms] | +| --------------------- | :--------------------------: | :------------------------------: | :-------------------: | :-------------------------------: | :-----------------------: | +| BK_SDM_TINY_VPRED_256 | 19100 | 25000 | ❌ | ❌ | 23100 | + +:::info +Text-to-image benchmark times are measured generating 256×256 images in 10 inference steps. +::: diff --git a/docs/docs/03-typescript-api/02-computer-vision/TextToImageModule.md b/docs/docs/03-typescript-api/02-computer-vision/TextToImageModule.md new file mode 100644 index 0000000000..aa5f67c8d1 --- /dev/null +++ b/docs/docs/03-typescript-api/02-computer-vision/TextToImageModule.md @@ -0,0 +1,82 @@ +--- +title: TextToImageModule +--- + +TypeScript API implementation of the [useTextToImage](../../02-hooks/02-computer-vision/useTextToImage.md) hook. 
+ +## Reference + +```typescript +import { + TextToImageModule, + BK_SDM_TINY_VPRED_256, +} from 'react-native-executorch'; + +const input = 'a castle'; + +// Creating an instance +const textToImageModule = new TextToImageModule(); + +// Loading the model +await textToImageModule.load(BK_SDM_TINY_VPRED_256); + +// Running the model +const image = await textToImageModule.forward(input); +``` + +### Methods + +| Method | Type | Description | +| ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `constructor` | `(inferenceCallback?: (stepIdx: number) => void)` | Creates a new instance of TextToImageModule with optional callback on inference step. | +| `load` | `(model: {tokenizerSource: ResourceSource; schedulerSource: ResourceSource; encoderSource: ResourceSource; unetSource: ResourceSource; decoderSource: ResourceSource;}, onDownloadProgressCallback: (progress: number) => void): Promise` | Loads the model. | +| `forward` | `(input: string, imageSize: number, numSteps: number, seed?: number) => Promise` | Runs the model to generate an image described by `input`, and conditioned by `seed`, performing `numSteps` inference steps. The resulting image, with dimensions `imageSize`×`imageSize` pixels, is returned as a base64-encoded string. | +| `delete` | `() => void` | Deletes the model from memory. Note you cannot delete model while it's generating. You need to interrupt it first and make sure model stopped generation. | +| `interrupt` | `() => void` | Interrupts model generation. The model is stopped in the nearest step. | + +
+Type definitions + +```typescript +type ResourceSource = string | number | object; +``` + +
+ +## Loading the model + +To load the model, use the `load` method. It accepts two arguments: + +**`model`** - Object containing the model source. + +- **`schedulerSource`** - A string that specifies the location of the scheduler config. + +- **`tokenizerSource`** - A string that specifies the location of the tokenizer config. + +- **`encoderSource`** - A string that specifies the location of the text encoder binary. + +- **`unetSource`** - A string that specifies the location of the U-Net binary. + +- **`decoderSource`** - A string that specifies the location of the VAE decoder binary. + +**`onDownloadProgressCallback`** - (Optional) Function called on download progress. + +This method returns a promise, which can resolve to an error or void. + +For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. + +## Running the model + +To run the model, you can use the `forward` method. It accepts four arguments: a text prompt describing the requested image, a size of the image in pixels, a number of denoising steps, and an optional seed value, which enables reproducibility of the results. + +The image size must fall within the range from 128 to 512 unless specified differently, and be a multiple of 32 due to the architecture of the U-Net and VAE models. + +The seed value should be a positive integer. + +## Listening for inference steps + +To monitor the progress of image generation, you can pass an `inferenceCallback` function to the constructor. The callback is invoked at each denoising step (for a total of `numSteps + 1` times), yielding the current step index that can be used, for example, to display a progress bar. + +## Deleting the model from memory + +To delete the model from memory, you can use the `delete` method. 
diff --git a/docs/docs/04-benchmarks/inference-time.md b/docs/docs/04-benchmarks/inference-time.md index 504c0f6e9b..dd0f1275a7 100644 --- a/docs/docs/04-benchmarks/inference-time.md +++ b/docs/docs/04-benchmarks/inference-time.md @@ -66,27 +66,27 @@ Times presented in the tables are measured as consecutive runs of the model. Ini Notice than for `Whisper` model which has to take as an input 30 seconds audio chunks (for shorter audio it is automatically padded with silence to 30 seconds) `fast` mode has the lowest latency (time from starting transcription to first token returned, caused by streaming algorithm), but the slowest speed. If you believe that this might be a problem for you, prefer `balanced` mode instead. -| Model (mode) | iPhone 16 Pro (XNNPACK) [latency \| tokens/s] | iPhone 14 Pro (XNNPACK) [latency \| tokens/s] | iPhone SE 3 (XNNPACK) [latency \| tokens/s] | Samsung Galaxy S24 (XNNPACK) [latency \| tokens/s] | OnePlus 12 (XNNPACK) [latency \| tokens/s] | -| ------------------------- | :-------------------------------------------: | :-------------------------------------------: | :-----------------------------------------: | :------------------------------------------------: | :----------------------------------------: | -| Whisper-tiny (fast) | 2.8s \| 5.5t/s | 3.7s \| 4.4t/s | 4.4s \| 3.4t/s | 5.5s \| 3.1t/s | 5.3s \| 3.8t/s | -| Whisper-tiny (balanced) | 5.6s \| 7.9t/s | 7.0s \| 6.3t/s | 8.3s \| 5.0t/s | 8.4s \| 6.7t/s | 7.7s \| 7.2t/s | -| Whisper-tiny (quality) | 10.3s \| 8.3t/s | 12.6s \| 6.8t/s | 7.8s \| 8.9t/s | 13.5s \| 7.1t/s | 12.9s \| 7.5t/s | +| Model (mode) | iPhone 16 Pro (XNNPACK) [latency \| tokens/s] | iPhone 14 Pro (XNNPACK) [latency \| tokens/s] | iPhone SE 3 (XNNPACK) [latency \| tokens/s] | Samsung Galaxy S24 (XNNPACK) [latency \| tokens/s] | OnePlus 12 (XNNPACK) [latency \| tokens/s] | +| ----------------------- | :-------------------------------------------: | :-------------------------------------------: | 
:-----------------------------------------: | :------------------------------------------------: | :----------------------------------------: | +| Whisper-tiny (fast) | 2.8s \| 5.5t/s | 3.7s \| 4.4t/s | 4.4s \| 3.4t/s | 5.5s \| 3.1t/s | 5.3s \| 3.8t/s | +| Whisper-tiny (balanced) | 5.6s \| 7.9t/s | 7.0s \| 6.3t/s | 8.3s \| 5.0t/s | 8.4s \| 6.7t/s | 7.7s \| 7.2t/s | +| Whisper-tiny (quality) | 10.3s \| 8.3t/s | 12.6s \| 6.8t/s | 7.8s \| 8.9t/s | 13.5s \| 7.1t/s | 12.9s \| 7.5t/s | ### Encoding Average time for encoding audio of given length over 10 runs. For `Whisper` model we only list 30 sec audio chunks since `Whisper` does not accept other lengths (for shorter audio the audio needs to be padded to 30sec with silence). -| Model | iPhone 16 Pro (XNNPACK) [ms] | iPhone 14 Pro (XNNPACK) [ms] | iPhone SE 3 (XNNPACK) [ms] | Samsung Galaxy S24 (XNNPACK) [ms] | OnePlus 12 (XNNPACK) [ms] | -| -------------------- | :--------------------------: | :--------------------------: | :------------------------: | :-------------------------------: | :-----------------------: | -| Whisper-tiny (30s) | 1034 | 1344 | 1269 | 2916 | 2143 | +| Model | iPhone 16 Pro (XNNPACK) [ms] | iPhone 14 Pro (XNNPACK) [ms] | iPhone SE 3 (XNNPACK) [ms] | Samsung Galaxy S24 (XNNPACK) [ms] | OnePlus 12 (XNNPACK) [ms] | +| ------------------ | :--------------------------: | :--------------------------: | :------------------------: | :-------------------------------: | :-----------------------: | +| Whisper-tiny (30s) | 1034 | 1344 | 1269 | 2916 | 2143 | ### Decoding Average time for decoding one token in sequence of 100 tokens, with encoding context is obtained from audio of noted length. 
-| Model | iPhone 16 Pro (XNNPACK) [ms] | iPhone 14 Pro (XNNPACK) [ms] | iPhone SE 3 (XNNPACK) [ms] | Samsung Galaxy S24 (XNNPACK) [ms] | OnePlus 12 (XNNPACK) [ms] | -| -------------------- | :--------------------------: | :--------------------------: | :------------------------: | :-------------------------------: | :-----------------------: | -| Whisper-tiny (30s) | 128.03 | 113.65 | 141.63 | 89.08 | 84.49 | +| Model | iPhone 16 Pro (XNNPACK) [ms] | iPhone 14 Pro (XNNPACK) [ms] | iPhone SE 3 (XNNPACK) [ms] | Samsung Galaxy S24 (XNNPACK) [ms] | OnePlus 12 (XNNPACK) [ms] | +| ------------------ | :--------------------------: | :--------------------------: | :------------------------: | :-------------------------------: | :-----------------------: | +| Whisper-tiny (30s) | 128.03 | 113.65 | 141.63 | 89.08 | 84.49 | ## Text Embeddings @@ -111,3 +111,11 @@ Benchmark times for text embeddings are highly dependent on the sentence length. :::info Image embedding benchmark times are measured using 224×224 pixel images, as required by the model. All input images, whether larger or smaller, are resized to 224×224 before processing. Resizing is typically fast for small images but may be noticeably slower for very large images, which can increase total inference time. ::: + +## Text to Image + +Average time for generating one image of size 256×256 in 10 inference steps. 
+ +| Model | iPhone 16 Pro (XNNPACK) [ms] | iPhone 14 Pro Max (XNNPACK) [ms] | iPhone SE 3 (XNNPACK) | Samsung Galaxy S24 (XNNPACK) [ms] | OnePlus 12 (XNNPACK) [ms] | +| --------------------- | :--------------------------: | :------------------------------: | :-------------------: | :-------------------------------: | :-----------------------: | +| BK_SDM_TINY_VPRED_256 | 19100 | 25000 | ❌ | ❌ | 23100 | diff --git a/docs/docs/04-benchmarks/memory-usage.md b/docs/docs/04-benchmarks/memory-usage.md index 684020e2af..e34c8a7ca9 100644 --- a/docs/docs/04-benchmarks/memory-usage.md +++ b/docs/docs/04-benchmarks/memory-usage.md @@ -68,3 +68,10 @@ title: Memory Usage | Model | Android (XNNPACK) [MB] | iOS (XNNPACK) [MB] | | --------------------------- | :--------------------: | :----------------: | | CLIP_VIT_BASE_PATCH32_IMAGE | 350 | 340 | + +## Text to Image + +| Model | Android (XNNPACK) [MB] | iOS (XNNPACK) [MB] | +| --------------------- | ---------------------- | ------------------ | +| BK_SDM_TINY_VPRED_256 | 2900 | 2800 | +| BK_SDM_TINY_VPRED_512 | 6700 | 6560 | diff --git a/docs/docs/04-benchmarks/model-size.md b/docs/docs/04-benchmarks/model-size.md index 9d20c95d5b..5cf87f6faa 100644 --- a/docs/docs/04-benchmarks/model-size.md +++ b/docs/docs/04-benchmarks/model-size.md @@ -82,3 +82,9 @@ title: Model Size | Model | XNNPACK [MB] | | --------------------------- | :----------: | | CLIP_VIT_BASE_PATCH32_IMAGE | 352 | + +## Text to Image + +| Model | Text encoder (XNNPACK) [MB] | UNet (XNNPACK) [MB] | VAE decoder (XNNPACK) [MB] | +| ----------------- | --------------------------- | ------------------- | -------------------------- | +| BK_SDM_TINY_VPRED | 492 | 1290 | 198 | diff --git a/docs/static/img/castle256.png b/docs/static/img/castle256.png new file mode 100644 index 0000000000..8f6197066d Binary files /dev/null and b/docs/static/img/castle256.png differ diff --git a/docs/static/img/castle512.png b/docs/static/img/castle512.png new file mode 100644 index 
0000000000..d7607c1fde Binary files /dev/null and b/docs/static/img/castle512.png differ diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp index 70663c239c..ed0d37f927 100644 --- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -42,6 +43,11 @@ void RnExecutorchInstaller::injectJSIBindings( models::image_segmentation::ImageSegmentation>( jsiRuntime, jsCallInvoker, "loadImageSegmentation")); + jsiRuntime->global().setProperty( + *jsiRuntime, "loadTextToImage", + RnExecutorchInstaller::loadModel( + jsiRuntime, jsCallInvoker, "loadTextToImage")); + jsiRuntime->global().setProperty( *jsiRuntime, "loadClassification", RnExecutorchInstaller::loadModel( @@ -49,9 +55,8 @@ void RnExecutorchInstaller::injectJSIBindings( jsiRuntime->global().setProperty( *jsiRuntime, "loadObjectDetection", - RnExecutorchInstaller::loadModel< - models::object_detection::ObjectDetection>(jsiRuntime, jsCallInvoker, - "loadObjectDetection")); + RnExecutorchInstaller::loadModel( + jsiRuntime, jsCallInvoker, "loadObjectDetection")); jsiRuntime->global().setProperty( *jsiRuntime, "loadExecutorchModule", diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index 4caf5c1328..f512dce9d7 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -19,6 +19,7 @@ #include #include #include +#include #include namespace rnexecutorch { @@ -117,6 +118,12 @@ template class ModelHostObject : public JsiHostObject { 
JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); } + if constexpr (meta::SameAs) { + addFunctions(JSI_EXPORT_FUNCTION( + ModelHostObject, synchronousHostFunction<&Model::interrupt>, + "interrupt")); + } + if constexpr (meta::SameAs) { addFunctions( JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); diff --git a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp index 2f7bf69457..c452aa331d 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp @@ -48,7 +48,6 @@ TextEmbeddings::generate(const std::string input) { attnMaskShape, preprocessed.attentionMask.data(), ScalarType::Long); auto forwardResult = BaseModel::forward({tokenIds, attnMask}); - if (!forwardResult.ok()) { throw std::runtime_error( "Function forward in TextEmbeddings failed with error code: " + diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Constants.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Constants.h new file mode 100644 index 0000000000..c6af48e82e --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Constants.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +namespace rnexecutorch::models::text_to_image::constants { + +inline constexpr std::string_view kBosToken = "<|startoftext|>"; + +} // namespace rnexecutorch::models::text_to_image::constants diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Decoder.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Decoder.cpp new file mode 100644 index 0000000000..ceef511753 --- /dev/null +++ 
b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Decoder.cpp @@ -0,0 +1,32 @@ +#include "Decoder.h" + +#include + +#include + +namespace rnexecutorch::models::text_to_image { + +using namespace executorch::extension; + +Decoder::Decoder(const std::string &modelSource, + std::shared_ptr callInvoker) + : BaseModel(modelSource, callInvoker) {} + +std::vector Decoder::generate(std::vector &input) const { + std::vector inputShape = {1, numChannels, latentImageSize, + latentImageSize}; + auto inputTensor = + make_tensor_ptr(inputShape, input.data(), ScalarType::Float); + + auto forwardResult = BaseModel::forward(inputTensor); + if (!forwardResult.ok()) { + throw std::runtime_error( + "Function forward in decoder failed with error code: " + + std::to_string(static_cast(forwardResult.error()))); + } + + auto forwardResultTensor = forwardResult->at(0).toTensor(); + const auto *dataPtr = forwardResultTensor.const_data_ptr(); + return {dataPtr, dataPtr + forwardResultTensor.numel()}; +} +} // namespace rnexecutorch::models::text_to_image diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Decoder.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Decoder.h new file mode 100644 index 0000000000..c0b35c102a --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Decoder.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include +#include + +#include + +#include + +namespace rnexecutorch::models::text_to_image { + +class Decoder final : public BaseModel { +public: + explicit Decoder(const std::string &modelSource, + std::shared_ptr callInvoker); + std::vector generate(std::vector &input) const; + + int32_t latentImageSize; + +private: + static constexpr int32_t numChannels = 4; +}; +} // namespace rnexecutorch::models::text_to_image diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp 
b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp new file mode 100644 index 0000000000..68a9a9fef4 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.cpp @@ -0,0 +1,44 @@ +#include "Encoder.h" + +#include +#include +#include + +#include + +namespace rnexecutorch::models::text_to_image { + +Encoder::Encoder(const std::string &tokenizerSource, + const std::string &encoderSource, + std::shared_ptr callInvoker) + : callInvoker(callInvoker), + encoder(std::make_unique( + encoderSource, tokenizerSource, callInvoker)) {} + +std::vector Encoder::generate(std::string input) { + std::shared_ptr embeddingsText = encoder->generate(input); + std::shared_ptr embeddingsUncond = + encoder->generate(std::string(constants::kBosToken)); + + assert(embeddingsText->size() == embeddingsUncond->size()); + size_t embeddingsSize = embeddingsText->size() / sizeof(float); + auto *embeddingsTextPtr = reinterpret_cast(embeddingsText->data()); + auto *embeddingsUncondPtr = + reinterpret_cast(embeddingsUncond->data()); + + std::vector embeddingsConcat; + embeddingsConcat.reserve(embeddingsSize * 2); + embeddingsConcat.insert(embeddingsConcat.end(), embeddingsUncondPtr, + embeddingsUncondPtr + embeddingsSize); + embeddingsConcat.insert(embeddingsConcat.end(), embeddingsTextPtr, + embeddingsTextPtr + embeddingsSize); + return embeddingsConcat; +} + +size_t Encoder::getMemoryLowerBound() const noexcept { + return encoder->getMemoryLowerBound(); +} + +void Encoder::unload() noexcept { encoder->unload(); } + +} // namespace rnexecutorch::models::text_to_image diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.h new file mode 100644 index 0000000000..b444f30ab9 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Encoder.h @@ -0,0 +1,32 @@ 
+#pragma once + +#include +#include +#include + +#include +#include + +#include + +#include + +namespace rnexecutorch { +namespace models::text_to_image { +using namespace facebook; + +class Encoder final { +public: + explicit Encoder(const std::string &tokenizerSource, + const std::string &encoderSource, + std::shared_ptr callInvoker); + std::vector generate(std::string input); + size_t getMemoryLowerBound() const noexcept; + void unload() noexcept; + +private: + std::shared_ptr callInvoker; + std::unique_ptr encoder; +}; +} // namespace models::text_to_image +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Scheduler.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Scheduler.cpp new file mode 100644 index 0000000000..e20bf60bba --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Scheduler.cpp @@ -0,0 +1,152 @@ +// The implementation of the PNDMScheduler class from the diffusers library +// https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py + +#include "Scheduler.h" + +#include +#include + +namespace rnexecutorch::models::text_to_image { +using namespace facebook; + +Scheduler::Scheduler(float betaStart, float betaEnd, int32_t numTrainTimesteps, + int32_t stepsOffset, + std::shared_ptr callInvoker) + : numTrainTimesteps(numTrainTimesteps), stepsOffset(stepsOffset) { + const float start = std::sqrt(betaStart); + const float end = std::sqrt(betaEnd); + const float step = (end - start) / (numTrainTimesteps - 1); + + float runningProduct = 1.0f; + alphas.reserve(numTrainTimesteps); + // alphasCumprod[t] — fraction of the signal remaining after t steps + alphasCumprod.reserve(numTrainTimesteps); + // betas[t] — amount of noise injected at timestep t + betas.reserve(numTrainTimesteps); + for (int32_t i = 0; i < numTrainTimesteps; ++i) { + const float value = start + step * i; + const float 
beta = value * value; + betas.push_back(beta); + + const float alpha = 1.0f - beta; + alphas.push_back(alpha); + + runningProduct *= alpha; + alphasCumprod.push_back(runningProduct); + } + + // finalAlphaCumprod — signal at the first training step (highest + // signal-to-noise ratio) used as reference at the end of diffusion process + if (!alphasCumprod.empty()) { + finalAlphaCumprod = alphasCumprod[0]; + } +} + +void Scheduler::setTimesteps(size_t numInferenceSteps) { + this->numInferenceSteps = numInferenceSteps; + ets.clear(); + + if (numInferenceSteps < 2) { + timesteps = {1}; + return; + } + + timesteps.clear(); + timesteps.reserve(numInferenceSteps + 1); + + float numStepsRatio = + static_cast(numTrainTimesteps) / numInferenceSteps; + for (size_t i = 0; i < numInferenceSteps; i++) { + const auto timestep = + static_cast(std::round(i * numStepsRatio)) + stepsOffset; + timesteps.push_back(timestep); + } + // Duplicate the timestep to provide enough points for the solver + timesteps.insert(timesteps.end() - 1, timesteps[numInferenceSteps - 2]); + std::ranges::reverse(timesteps); +} + +std::vector Scheduler::step(const std::vector &sample, + const std::vector &noise, + int32_t timestep) { + if (numInferenceSteps == 0) { + throw std::runtime_error( + "Number of inference steps is not set. 
Call `set_timesteps` first."); + } + + size_t noiseSize = noise.size(); + std::vector etsOutput(noiseSize); + float numStepsRatio = + static_cast(numTrainTimesteps) / numInferenceSteps; + float timestepPrev = timestep - numStepsRatio; + + if (ets.empty()) { + ets.push_back(noise); + etsOutput = noise; + tempFirstSample = sample; + return getPrevSample(sample, etsOutput, timestep, timestepPrev); + } + + // Use the previous sample as the estimate requires at least 2 points + if (ets.size() == 1 && !tempFirstSample.empty()) { + for (size_t i = 0; i < noiseSize; i++) { + etsOutput[i] = (noise[i] + ets[0][i]) / 2; + } + auto prevSample = getPrevSample(std::move(tempFirstSample), etsOutput, + timestep + numStepsRatio, timestep); + tempFirstSample.clear(); + return prevSample; + } + + // Coefficients come from the linear multistep method + // https://en.wikipedia.org/wiki/Linear_multistep_method + ets.push_back(noise); + + if (ets.size() == 2) { + for (size_t i = 0; i < noiseSize; i++) { + etsOutput[i] = (ets[1][i] * 3 - ets[0][i]) / 2; + } + } else if (ets.size() == 3) { + for (size_t i = 0; i < noiseSize; i++) { + etsOutput[i] = ((ets[2][i] * 23 - ets[1][i] * 16) + ets[0][i] * 5) / 12; + } + } else { + ets.assign(ets.end() - 4, ets.end()); + for (size_t i = 0; i < noiseSize; i++) { + etsOutput[i] = + (ets[3][i] * 55 - ets[2][i] * 59 + ets[1][i] * 37 - ets[0][i] * 9) / + 24; + } + } + return getPrevSample(sample, etsOutput, timestep, timestepPrev); +} + +std::vector Scheduler::getPrevSample(const std::vector &sample, + const std::vector &noise, + int32_t timestep, + int32_t timestepPrev) const { + const float alpha = alphasCumprod[timestep]; + const float alphaPrev = + timestepPrev >= 0 ? 
alphasCumprod[timestepPrev] : finalAlphaCumprod; + const float beta = 1 - alpha; + const float betaPrev = 1 - alphaPrev; + + size_t noiseSize = noise.size(); + const float noiseCoeff = + (alphaPrev - alpha) / + (alpha * std::sqrt(betaPrev) + std::sqrt(alpha * beta * alphaPrev)); + const float sampleCoeff = std::sqrt(alphaPrev / alpha); + + std::vector samplePrev; + samplePrev.reserve(noiseSize); + for (size_t i = 0; i < noiseSize; i++) { + const float noiseTerm = + (noise[i] * std::sqrt(alpha) + sample[i] * std::sqrt(beta)) * + noiseCoeff; + samplePrev.push_back(sample[i] * sampleCoeff - noiseTerm); + } + + return samplePrev; +} + +} // namespace rnexecutorch::models::text_to_image diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Scheduler.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Scheduler.h new file mode 100644 index 0000000000..99ef9fb0c6 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/Scheduler.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include + +#include + +namespace rnexecutorch::models::text_to_image { + +using namespace facebook; + +class Scheduler final { +public: + explicit Scheduler(float betaStart, float betaEnd, int32_t numTrainTimesteps, + int32_t stepsOfset, + std::shared_ptr callInvoker); + void setTimesteps(size_t numInferenceSteps); + std::vector step(const std::vector &sample, + const std::vector &noise, int32_t timestep); + + std::vector timesteps; + +private: + int32_t numTrainTimesteps; + int32_t stepsOffset; + + std::vector betas; + std::vector alphas; + std::vector alphasCumprod; + std::vector tempFirstSample; + std::vector> ets; + float finalAlphaCumprod{1.0f}; + float initNoiseSigma{1.0f}; + + size_t numInferenceSteps{0}; + + std::vector getPrevSample(const std::vector &sample, + const std::vector &noise, + int32_t timestep, + int32_t prevTimestep) const; +}; +} // namespace rnexecutorch::models::text_to_image diff 
--git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.cpp new file mode 100644 index 0000000000..0fd7a717af --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.cpp @@ -0,0 +1,142 @@ +#include "TextToImage.h" + +#include +#include +#include + +#include + +#include +#include + +namespace rnexecutorch::models::text_to_image { + +using namespace executorch::extension; + +TextToImage::TextToImage(const std::string &tokenizerSource, + const std::string &encoderSource, + const std::string &unetSource, + const std::string &decoderSource, + float schedulerBetaStart, float schedulerBetaEnd, + int32_t schedulerNumTrainTimesteps, + int32_t schedulerStepsOffset, + std::shared_ptr callInvoker) + : callInvoker(callInvoker), + scheduler(std::make_unique( + schedulerBetaStart, schedulerBetaEnd, schedulerNumTrainTimesteps, + schedulerStepsOffset, callInvoker)), + encoder(std::make_unique(tokenizerSource, encoderSource, + callInvoker)), + unet(std::make_unique(unetSource, callInvoker)), + decoder(std::make_unique(decoderSource, callInvoker)) {} + +void TextToImage::setImageSize(int32_t imageSize) { + if (imageSize % 32 != 0) { + throw std::runtime_error("Image size must be a multiple of 32."); + } + this->imageSize = imageSize; + constexpr int32_t latentDownsample = 8; + latentImageSize = std::floor(imageSize / latentDownsample); + unet->latentImageSize = latentImageSize; + decoder->latentImageSize = latentImageSize; +} + +void TextToImage::setSeed(int32_t &seed) { + // Seed argument is provided + if (seed >= 0) { + return; + } + std::random_device rd; + seed = rd(); +} + +std::shared_ptr +TextToImage::generate(std::string input, int32_t imageSize, + size_t numInferenceSteps, int32_t seed, + std::shared_ptr callback) { + setImageSize(imageSize); + setSeed(seed); + + std::vector embeddings = 
encoder->generate(input); + std::vector embeddingsShape = {2, 77, 768}; + auto embeddingsTensor = + make_tensor_ptr(embeddingsShape, embeddings.data(), ScalarType::Float); + + constexpr int32_t latentDownsample = 8; + int32_t latentsSize = numChannels * latentImageSize * latentImageSize; + std::vector latents(latentsSize); + std::mt19937 gen(seed); + std::normal_distribution dist(0.0, 1.0); + for (auto &val : latents) { + val = dist(gen); + } + + scheduler->setTimesteps(numInferenceSteps); + std::vector timesteps = scheduler->timesteps; + + auto nativeCallback = [this, callback](size_t stepIdx) { + this->callInvoker->invokeAsync([callback, stepIdx](jsi::Runtime &runtime) { + callback->call(runtime, jsi::Value(static_cast(stepIdx))); + }); + }; + for (size_t t = 0; t < numInferenceSteps + 1 && !interrupted; t++) { + log(LOG_LEVEL::Debug, "Step:", t, "/", numInferenceSteps); + + std::vector noisePred = + unet->generate(latents, timesteps[t], embeddingsTensor); + + size_t noiseSize = noisePred.size() / 2; + std::span noisePredSpan{noisePred}; + std::span noiseUncond = noisePredSpan.subspan(0, noiseSize); + std::span noiseText = + noisePredSpan.subspan(noiseSize, noiseSize); + std::vector noise(noiseSize); + for (size_t i = 0; i < noiseSize; i++) { + noise[i] = + noiseUncond[i] * (1 - guidanceScale) + noiseText[i] * guidanceScale; + } + latents = scheduler->step(latents, noise, timesteps[t]); + + nativeCallback(t); + } + if (interrupted) { + interrupted = false; + return std::make_shared(0); + } + + for (auto &val : latents) { + val /= latentsScale; + } + + std::vector output = decoder->generate(latents); + return postprocess(output); +} + +std::shared_ptr +TextToImage::postprocess(const std::vector &output) const { + // Convert RGB to RGBA + int32_t imagePixelCount = imageSize * imageSize; + std::vector outputRgba(imagePixelCount * 4); + for (int32_t i = 0; i < imagePixelCount; i++) { + outputRgba[i * 4 + 0] = output[i * 3 + 0]; + outputRgba[i * 4 + 1] = output[i * 3 
+ 1]; + outputRgba[i * 4 + 2] = output[i * 3 + 2]; + outputRgba[i * 4 + 3] = 255; + } + return std::make_shared(outputRgba); +} + +void TextToImage::interrupt() noexcept { interrupted = true; } + +size_t TextToImage::getMemoryLowerBound() const noexcept { + return encoder->getMemoryLowerBound() + unet->getMemoryLowerBound() + + decoder->getMemoryLowerBound(); +} + +void TextToImage::unload() noexcept { + encoder->unload(); + unet->unload(); + decoder->unload(); +} + +} // namespace rnexecutorch::models::text_to_image diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.h new file mode 100644 index 0000000000..18316217cd --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include +#include +#include + +namespace rnexecutorch { +namespace models::text_to_image { +using namespace facebook; + +class TextToImage final { +public: + explicit TextToImage(const std::string &tokenizerSource, + const std::string &encoderSource, + const std::string &unetSource, + const std::string &decoderSource, + float schedulerBetaStart, float schedulerBetaEnd, + int32_t schedulerNumTrainTimesteps, + int32_t schedulerStepsOffset, + std::shared_ptr callInvoker); + std::shared_ptr + generate(std::string input, int32_t imageSize, size_t numInferenceSteps, + int32_t seed, std::shared_ptr callback); + void interrupt() noexcept; + size_t getMemoryLowerBound() const noexcept; + void unload() noexcept; + +private: + void setImageSize(int32_t imageSize); + void setSeed(int32_t &seed); + std::shared_ptr + postprocess(const std::vector &output) const; + + size_t memorySizeLowerBound; + int32_t imageSize; + int32_t latentImageSize; + static constexpr int32_t numChannels = 4; + static constexpr 
float guidanceScale = 7.5f; + static constexpr float latentsScale = 0.18215f; + bool interrupted = false; + + std::shared_ptr callInvoker; + std::unique_ptr scheduler; + std::unique_ptr encoder; + std::unique_ptr unet; + std::unique_ptr decoder; +}; +} // namespace models::text_to_image + +REGISTER_CONSTRUCTOR(models::text_to_image::TextToImage, std::string, + std::string, std::string, std::string, float, float, + int32_t, int32_t, std::shared_ptr); +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/UNet.cpp b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/UNet.cpp new file mode 100644 index 0000000000..d5e7e97458 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/UNet.cpp @@ -0,0 +1,38 @@ +#include "UNet.h" + +namespace rnexecutorch::models::text_to_image { + +using namespace executorch::extension; + +UNet::UNet(const std::string &modelSource, + std::shared_ptr callInvoker) + : BaseModel(modelSource, callInvoker) {} + +std::vector UNet::generate(std::vector &latents, int32_t timestep, + TensorPtr &embeddingsTensor) const { + std::vector latentsConcat; + latentsConcat.reserve(2 * latentImageSize); + latentsConcat.insert(latentsConcat.end(), latents.begin(), latents.end()); + latentsConcat.insert(latentsConcat.end(), latents.begin(), latents.end()); + + std::vector latentsShape = {2, numChannels, latentImageSize, + latentImageSize}; + + auto timestepTensor = + make_tensor_ptr({static_cast(timestep)}); + auto latentsTensor = + make_tensor_ptr(latentsShape, latentsConcat.data(), ScalarType::Float); + + auto forwardResult = + BaseModel::forward({latentsTensor, timestepTensor, embeddingsTensor}); + if (!forwardResult.ok()) { + throw std::runtime_error( + "Function forward in UNet failed with error code: " + + std::to_string(static_cast(forwardResult.error()))); + } + + auto forwardResultTensor = forwardResult->at(0).toTensor(); + const 
auto *dataPtr = forwardResultTensor.const_data_ptr(); + return {dataPtr, dataPtr + forwardResultTensor.numel()}; +} +} // namespace rnexecutorch::models::text_to_image diff --git a/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/UNet.h b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/UNet.h new file mode 100644 index 0000000000..0c6dd057c6 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/text_to_image/UNet.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include + +#include + +#include +#include + +namespace rnexecutorch::models::text_to_image { + +using namespace executorch::extension; + +class UNet final : public BaseModel { +public: + explicit UNet(const std::string &modelSource, + std::shared_ptr callInvoker); + std::vector generate(std::vector &latents, int32_t timestep, + TensorPtr &embeddingsTensor) const; + + int32_t latentImageSize; + +private: + static constexpr int32_t numChannels = 4; +}; +} // namespace rnexecutorch::models::text_to_image diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index c1b5c0133b..34834733eb 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -422,3 +422,20 @@ export const CLIP_VIT_BASE_PATCH32_TEXT = { modelSource: CLIP_VIT_BASE_PATCH32_TEXT_MODEL, tokenizerSource: CLIP_VIT_BASE_PATCH32_TEXT_TOKENIZER, }; + +// Image generation +export const BK_SDM_TINY_VPRED_512 = { + schedulerSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/scheduler/scheduler_config.json`, + tokenizerSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/tokenizer/tokenizer.json`, + encoderSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/text_encoder/model.pte`, + unetSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/unet/model.pte`, + decoderSource: 
`${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/vae/model.pte`, +}; + +export const BK_SDM_TINY_VPRED_256 = { + schedulerSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/scheduler/scheduler_config.json`, + tokenizerSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/tokenizer/tokenizer.json`, + encoderSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/text_encoder/model.pte`, + unetSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/unet/model.256.pte`, + decoderSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/vae/model.256.pte`, +}; diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useTextToImage.ts b/packages/react-native-executorch/src/hooks/computer_vision/useTextToImage.ts new file mode 100644 index 0000000000..b6a514d225 --- /dev/null +++ b/packages/react-native-executorch/src/hooks/computer_vision/useTextToImage.ts @@ -0,0 +1,92 @@ +import { useCallback, useEffect, useState } from 'react'; +import { ETError, getError } from '../../Error'; +import { ResourceSource } from '../../types/common'; +import { TextToImageModule } from '../../modules/computer_vision/TextToImageModule'; + +interface TextToImageType { + isReady: boolean; + isGenerating: boolean; + downloadProgress: number; + error: string | null; + generate: ( + input: string, + imageSize?: number, + numSteps?: number, + seed?: number + ) => Promise; + interrupt: () => void; +} + +export const useTextToImage = ({ + model, + inferenceCallback, + preventLoad = false, +}: { + model: { + tokenizerSource: ResourceSource; + schedulerSource: ResourceSource; + encoderSource: ResourceSource; + unetSource: ResourceSource; + decoderSource: ResourceSource; + }; + inferenceCallback?: (stepIdx: number) => void; + preventLoad?: boolean; +}): TextToImageType => { + const [isReady, setIsReady] = useState(false); + const [isGenerating, setIsGenerating] = useState(false); + const [downloadProgress, setDownloadProgress] = useState(0); + const [error, setError] = useState(null); + + const [module] = useState(() 
=> new TextToImageModule(inferenceCallback)); + + useEffect(() => { + if (preventLoad) return; + + (async () => { + setDownloadProgress(0); + setError(null); + try { + setIsReady(false); + await module.load(model, setDownloadProgress); + setIsReady(true); + } catch (err) { + setError((err as Error).message); + } + })(); + + return () => { + module.delete(); + }; + }, [module, model, preventLoad]); + + const generate = async ( + input: string, + imageSize?: number, + numSteps?: number, + seed?: number + ): Promise => { + if (!isReady) throw new Error(getError(ETError.ModuleNotLoaded)); + if (isGenerating) throw new Error(getError(ETError.ModelGenerating)); + try { + setIsGenerating(true); + return await module.forward(input, imageSize, numSteps, seed); + } finally { + setIsGenerating(false); + } + }; + + const interrupt = useCallback(() => { + if (isGenerating) { + module.interrupt(); + } + }, [module, isGenerating]); + + return { + isReady, + isGenerating, + downloadProgress, + error, + generate, + interrupt, + }; +}; diff --git a/packages/react-native-executorch/src/index.ts b/packages/react-native-executorch/src/index.ts index 3ad5692ce5..6230221139 100644 --- a/packages/react-native-executorch/src/index.ts +++ b/packages/react-native-executorch/src/index.ts @@ -11,6 +11,16 @@ declare global { var loadImageEmbeddings: (source: string) => any; var loadTextEmbeddings: (modelSource: string, tokenizerSource: string) => any; var loadLLM: (modelSource: string, tokenizerSource: string) => any; + var loadTextToImage: ( + tokenizerSource: string, + encoderSource: string, + unetSource: string, + decoderSource: string, + schedulerBetaStart: number, + schedulerBetaEnd: number, + schedulerNumTrainTimesteps: number, + schedulerStepsOffset: number + ) => any; var loadSpeechToText: ( encoderSource: string, decoderSource: string, @@ -35,6 +45,7 @@ declare global { if ( global.loadStyleTransfer == null || global.loadImageSegmentation == null || + global.loadTextToImage == null || 
global.loadExecutorchModule == null || global.loadClassification == null || global.loadObjectDetection == null || @@ -62,6 +73,7 @@ export * from './hooks/computer_vision/useImageSegmentation'; export * from './hooks/computer_vision/useOCR'; export * from './hooks/computer_vision/useVerticalOCR'; export * from './hooks/computer_vision/useImageEmbeddings'; +export * from './hooks/computer_vision/useTextToImage'; export * from './hooks/natural_language_processing/useLLM'; export * from './hooks/natural_language_processing/useSpeechToText'; @@ -77,14 +89,16 @@ export * from './modules/computer_vision/StyleTransferModule'; export * from './modules/computer_vision/ImageSegmentationModule'; export * from './modules/computer_vision/OCRModule'; export * from './modules/computer_vision/VerticalOCRModule'; -export * from './modules/general/ExecutorchModule'; export * from './modules/computer_vision/ImageEmbeddingsModule'; +export * from './modules/computer_vision/TextToImageModule'; export * from './modules/natural_language_processing/LLMModule'; export * from './modules/natural_language_processing/SpeechToTextModule'; export * from './modules/natural_language_processing/TextEmbeddingsModule'; export * from './modules/natural_language_processing/TokenizerModule'; +export * from './modules/general/ExecutorchModule'; + // utils export * from './utils/ResourceFetcher'; export * from './utils/llm'; diff --git a/packages/react-native-executorch/src/modules/computer_vision/TextToImageModule.ts b/packages/react-native-executorch/src/modules/computer_vision/TextToImageModule.ts new file mode 100644 index 0000000000..cab509667a --- /dev/null +++ b/packages/react-native-executorch/src/modules/computer_vision/TextToImageModule.ts @@ -0,0 +1,93 @@ +import { ResourceFetcher } from '../../utils/ResourceFetcher'; +import { ResourceSource } from '../../types/common'; +import { BaseModule } from '../BaseModule'; +import { Buffer } from 'buffer'; +import { PNG } from 'pngjs/browser'; + 
+export class TextToImageModule extends BaseModule { + private inferenceCallback: (stepIdx: number) => void; + + constructor(inferenceCallback?: (stepIdx: number) => void) { + super(); + this.inferenceCallback = (stepIdx: number) => { + inferenceCallback?.(stepIdx); + }; + } + + async load( + model: { + tokenizerSource: ResourceSource; + schedulerSource: ResourceSource; + encoderSource: ResourceSource; + unetSource: ResourceSource; + decoderSource: ResourceSource; + }, + onDownloadProgressCallback: (progress: number) => void = () => {} + ): Promise { + const results = await ResourceFetcher.fetch( + onDownloadProgressCallback, + model.tokenizerSource, + model.schedulerSource, + model.encoderSource, + model.unetSource, + model.decoderSource + ); + if (!results) { + throw new Error('Failed to fetch one or more resources.'); + } + const [tokenizerPath, schedulerPath, encoderPath, unetPath, decoderPath] = + results; + + if ( + !tokenizerPath || + !schedulerPath || + !encoderPath || + !unetPath || + !decoderPath + ) { + throw new Error('Download interrupted.'); + } + + const response = await fetch('file://' + schedulerPath); + const schedulerConfig = await response.json(); + + this.nativeModule = global.loadTextToImage( + tokenizerPath, + encoderPath, + unetPath, + decoderPath, + schedulerConfig.beta_start, + schedulerConfig.beta_end, + schedulerConfig.num_train_timesteps, + schedulerConfig.steps_offset + ); + } + + async forward( + input: string, + imageSize: number = 512, + numSteps: number = 5, + seed?: number + ): Promise { + const output = await this.nativeModule.generate( + input, + imageSize, + numSteps, + seed ? 
seed : -1, + this.inferenceCallback + ); + const outputArray = new Uint8Array(output); + if (!outputArray.length) { + return ''; + } + const png = new PNG({ width: imageSize, height: imageSize }); + png.data = Buffer.from(outputArray); + const pngBuffer = PNG.sync.write(png, { colorType: 6 }); + const pngString = pngBuffer.toString('base64'); + return pngString; + } + + public interrupt(): void { + this.nativeModule.interrupt(); + } +} diff --git a/yarn.lock b/yarn.lock index 0e98536c65..902e5c5c28 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4228,6 +4228,15 @@ __metadata: languageName: node linkType: hard +"@types/pngjs@npm:^6.0.5": + version: 6.0.5 + resolution: "@types/pngjs@npm:6.0.5" + dependencies: + "@types/node": "npm:*" + checksum: 10/132fce25817d47a784ed48aa678332521b0f7e6edbaa76f3fa4e9ca1228078788ae712f99ad4d1a324d9ba0b14829958677eabf3ebef1fb6e120816f433f0cd8 + languageName: node + linkType: hard + "@types/prop-types@npm:*": version: 15.7.14 resolution: "@types/prop-types@npm:15.7.14" @@ -5836,6 +5845,7 @@ __metadata: "@react-navigation/drawer": "npm:^7.4.1" "@react-navigation/native": "npm:^7.1.10" "@shopify/react-native-skia": "npm:v2.0.0-next.2" + "@types/pngjs": "npm:^6.0.5" "@types/react": "npm:~19.0.10" expo: "npm:^53.0.0" expo-constants: "npm:~17.1.6"