feat: Add text to image pipeline #586
Merged
75 commits
70f463d
Text2image placeholder
a-szymanska a187b74
Encoder working
a-szymanska 6af3afc
Scheduler draft
a-szymanska b793eb0
Scheduler update step method
a-szymanska e514072
UNet draft
a-szymanska c0872f5
UNet working fine
a-szymanska 399a93c
VAE decoder working fine
a-szymanska 47c893a
Update app
a-szymanska 471224f
Cleaning
a-szymanska 6eef077
Update yarn.lock to include @types/pngjs
a-szymanska 0e57556
Reset text embeddings app to main
a-szymanska 86fc14b
Address PR review comments
a-szymanska 398bb68
Address PR review comments
a-szymanska aba200f
Loading scheduler params from config
a-szymanska b5d8730
Changes in scheduler - part 1
a-szymanska 98f0f5f
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska bc32200
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska f25a047
Changes in scheduler - part 2
a-szymanska 177af1b
Addressing the review
a-szymanska 10bceaf
Quick fix scheduler signature
a-szymanska 32e15bd
Changes from review in TextToImage.cpp
a-szymanska 0162df4
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska a1e422c
Address changes from review
a-szymanska 830de46
Changes from review in text2image app
a-szymanska 4b16b43
Update download callback
a-szymanska 3e44559
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska c1c04c4
Changes in scheduler - part 3
a-szymanska 56bfe07
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska d72fac8
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska 997af4b
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska e9922b6
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska b11fd8f
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska 1d37315
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska e2977e9
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska 7e03dd5
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska 3b7ce7c
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska d7ad0b2
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska b5267fe
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska dc4e706
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska 8107bd0
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska f1a89c0
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska 6f17eb6
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska b52c63b
Moving postprocess to the model
a-szymanska 543f9ee
Includes cleanup
a-szymanska c393c6c
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska a6394ad
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska be0076a
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska c255864
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska f042402
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska 3709701
Cleanup includes
a-szymanska ea288ae
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska 02d69be
Adding interrupt
a-szymanska 75bd6df
Update app layout
a-szymanska 51472a1
Adding callback on inference step
a-szymanska a419f4a
Displaying progress in app
a-szymanska ff09f21
Return base 64 from TextToImageModule
a-szymanska e323344
Manage input inside UNet
a-szymanska 4b1eb93
Extract encoder class
a-szymanska a85c258
Address changes from review
a-szymanska 70cd7ab
Bulk fetch in t2i module
a-szymanska 1d5d34d
Address changes from review
a-szymanska 5dbe18e
Address changes from review
a-szymanska 56b347c
Update packages/react-native-executorch/common/rnexecutorch/models/te…
a-szymanska 66493ee
Use dynamic shaped model, add numSteps button
a-szymanska 23d8270
One model for different image sizes
a-szymanska 694e912
Minor changes
a-szymanska 0583794
Docs and model urls
a-szymanska 3ee98d8
More changes
a-szymanska 1e3b6cc
Update docs
a-szymanska c7a831b
Adding seed value
a-szymanska ac7e754
Update docs - seed value
a-szymanska a1bbda3
Redesign demo app
a-szymanska 12f2c00
Update docs and model urls
a-szymanska 3e681ce
Minor fix in app
a-szymanska 936e114
Update model urls
a-szymanska
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -66,3 +66,11 @@ softmax | ||
| logit | ||
| logits | ||
| probs | ||
| unet | ||
| Unet | ||
| VPRED | ||
| timesteps | ||
| Timesteps | ||
| denoises | ||
| denoise | ||
| denoising | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,183 @@ | ||
| import { | ||
| View, | ||
| StyleSheet, | ||
| Text, | ||
| Image, | ||
| Keyboard, | ||
| TouchableWithoutFeedback, | ||
| } from 'react-native'; | ||
| import React, { useContext, useEffect, useState } from 'react'; | ||
| import Spinner from 'react-native-loading-spinner-overlay'; | ||
| import { useTextToImage, BK_SDM_TINY_VPRED_256 } from 'react-native-executorch'; | ||
| import { GeneratingContext } from '../../context'; | ||
| import ColorPalette from '../../colors'; | ||
| import ProgressBar from '../../components/ProgressBar'; | ||
| import { BottomBarWithTextInput } from '../../components/BottomBarWithTextInput'; | ||
|
|
||
| export default function TextToImageScreen() { | ||
| const [inferenceStepIdx, setInferenceStepIdx] = useState<number>(0); | ||
| const [imageTitle, setImageTitle] = useState<string | null>(null); | ||
| const [image, setImage] = useState<string | null>(null); | ||
| const [steps, setSteps] = useState<number>(10); | ||
| const [showTextInput, setShowTextInput] = useState(false); | ||
| const [keyboardVisible, setKeyboardVisible] = useState(false); | ||
|
|
||
| const imageSize = 224; | ||
| const model = useTextToImage({ | ||
| model: BK_SDM_TINY_VPRED_256, | ||
| inferenceCallback: (x) => setInferenceStepIdx(x), | ||
| }); | ||
|
|
||
| const { setGlobalGenerating } = useContext(GeneratingContext); | ||
|
|
||
| useEffect(() => { | ||
| setGlobalGenerating(model.isGenerating); | ||
| }, [model.isGenerating, setGlobalGenerating]); | ||
|
|
||
| useEffect(() => { | ||
| const showSub = Keyboard.addListener('keyboardDidShow', () => { | ||
| setKeyboardVisible(true); | ||
| }); | ||
| const hideSub = Keyboard.addListener('keyboardDidHide', () => { | ||
| setKeyboardVisible(false); | ||
| }); | ||
| return () => { | ||
| showSub.remove(); | ||
| hideSub.remove(); | ||
| }; | ||
| }, []); | ||
|
|
||
| const runForward = async (input: string, numSteps: number) => { | ||
| if (!input || !input.trim()) return; | ||
| const prevImageTitle = imageTitle; | ||
| setImageTitle(input); | ||
| setSteps(numSteps); | ||
| try { | ||
| // Pass numSteps directly: the `steps` state is not updated until the next render. | ||
| const output = await model.generate(input, imageSize, numSteps); | ||
| if (!output.length) { | ||
| setImageTitle(prevImageTitle); | ||
| return; | ||
| } | ||
| setImage(output); | ||
| } catch (e) { | ||
| console.error(e); | ||
| setImageTitle(null); | ||
| } finally { | ||
| setInferenceStepIdx(0); | ||
| } | ||
| }; | ||
|
|
||
| if (!model.isReady) { | ||
| // TODO: Update when #614 merged | ||
| return ( | ||
| <Spinner | ||
| visible={!model.isReady} | ||
| textContent={`Loading the model ${(model.downloadProgress * 100).toFixed(0)} %`} | ||
| /> | ||
| ); | ||
| } | ||
|
|
||
| return ( | ||
| <TouchableWithoutFeedback | ||
| onPress={() => { | ||
| Keyboard.dismiss(); | ||
| setShowTextInput(false); | ||
| }} | ||
| > | ||
| <View style={styles.container}> | ||
| {keyboardVisible && <View style={styles.overlay} />} | ||
|
|
||
| <View style={styles.titleContainer}> | ||
| {imageTitle && <Text style={styles.titleText}>{imageTitle}</Text>} | ||
| </View> | ||
|
|
||
| {model.isGenerating ? ( | ||
| <View style={styles.progressContainer}> | ||
| <Text style={styles.text}>Generating...</Text> | ||
| <ProgressBar numSteps={steps} currentStep={inferenceStepIdx} /> | ||
| </View> | ||
| ) : ( | ||
| <View style={styles.imageContainer}> | ||
| {image?.length ? ( | ||
| <Image | ||
| style={styles.image} | ||
| source={{ uri: `data:image/png;base64,${image}` }} | ||
| /> | ||
| ) : ( | ||
| <Image | ||
| style={styles.image} | ||
| source={require('../../assets/icons/executorch_logo.png')} | ||
| /> | ||
| )} | ||
| </View> | ||
| )} | ||
|
|
||
| <View style={styles.bottomContainer}> | ||
| <BottomBarWithTextInput | ||
| runModel={runForward} | ||
| stopModel={model.interrupt} | ||
| isGenerating={model.isGenerating} | ||
| isReady={model.isReady} | ||
| showTextInput={showTextInput} | ||
| setShowTextInput={setShowTextInput} | ||
| keyboardVisible={keyboardVisible} | ||
| /> | ||
| </View> | ||
| </View> | ||
| </TouchableWithoutFeedback> | ||
| ); | ||
| } | ||
|
|
||
| const styles = StyleSheet.create({ | ||
| container: { | ||
| flex: 1, | ||
| width: '100%', | ||
| alignItems: 'center', | ||
| }, | ||
| overlay: { | ||
| ...StyleSheet.absoluteFillObject, | ||
| backgroundColor: 'rgba(0,0,0,0.65)', | ||
| zIndex: 5, | ||
| }, | ||
| titleContainer: { | ||
| alignItems: 'center', | ||
| marginTop: 20, | ||
| }, | ||
| titleText: { | ||
| color: ColorPalette.primary, | ||
| fontSize: 20, | ||
| fontWeight: 'bold', | ||
| marginBottom: 12, | ||
| textAlign: 'center', | ||
| }, | ||
| text: { | ||
| fontSize: 16, | ||
| color: '#000', | ||
| }, | ||
| imageContainer: { | ||
| flex: 1, | ||
| position: 'absolute', | ||
| top: 100, | ||
| alignItems: 'center', | ||
| justifyContent: 'center', | ||
| }, | ||
| image: { | ||
| width: 256, | ||
| height: 256, | ||
| marginVertical: 30, | ||
| resizeMode: 'contain', | ||
| }, | ||
| progressContainer: { | ||
| flex: 1, | ||
| justifyContent: 'center', | ||
| alignItems: 'center', | ||
| }, | ||
| bottomContainer: { | ||
| flex: 1, | ||
| width: '90%', | ||
| position: 'absolute', | ||
| bottom: 0, | ||
| marginBottom: 25, | ||
| zIndex: 10, | ||
| }, | ||
| }); |
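One detail in `runForward` is worth illustrating: calling a React state setter does not change the value that the current render's closure has already captured, so a handler that calls `setSteps(numSteps)` and then reads `steps` still sees the previous value. The sketch below simulates this in plain TypeScript without React. `makeRender`, `runWithState`, and `runWithArgument` are hypothetical names used only for illustration and are not part of this PR.

```typescript
// Hypothetical stand-in for a single render: `steps` is captured by value,
// the way a state variable is captured by the closures created in that render.
function makeRender(steps: number, setSteps: (n: number) => void) {
  return {
    // Calls the setter, then reads the captured state value.
    runWithState(numSteps: number): number {
      setSteps(numSteps);
      return steps; // still the value this render started with
    },
    // Calls the setter, then uses the argument it was given.
    runWithArgument(numSteps: number): number {
      setSteps(numSteps);
      return numSteps;
    },
  };
}

// `stored` plays the role of React's internal state slot.
let stored = 10;
const render = makeRender(stored, (n) => {
  stored = n;
});

const staleResult = render.runWithState(25); // 10: the captured value
const freshResult = render.runWithArgument(25); // 25: the argument

console.log(staleResult, freshResult, stored);
```

The stored value is updated in both cases; only the next render would observe it through `steps`, which is why passing the argument through is the safer choice inside the handler.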
165 changes: 165 additions & 0 deletions
apps/computer-vision/components/BottomBarWithTextInput.tsx
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,165 @@ | ||
| import React, { useState } from 'react'; | ||
| import { | ||
| View, | ||
| Text, | ||
| TextInput, | ||
| TouchableOpacity, | ||
| StyleSheet, | ||
| KeyboardAvoidingView, | ||
| Platform, | ||
| } from 'react-native'; | ||
| import { Ionicons } from '@expo/vector-icons'; | ||
| import ColorPalette from '../colors'; | ||
|
|
||
| interface BottomBarProps { | ||
| runModel: (input: string, numSteps: number) => void; | ||
| stopModel: () => void; | ||
| isGenerating?: boolean; | ||
| isReady?: boolean; | ||
| showTextInput: boolean; | ||
| setShowTextInput: React.Dispatch<React.SetStateAction<boolean>>; | ||
| keyboardVisible: boolean; | ||
| } | ||
|
|
||
| export const BottomBarWithTextInput = ({ | ||
| runModel, | ||
| stopModel, | ||
| isGenerating, | ||
| isReady, | ||
| showTextInput, | ||
| setShowTextInput, | ||
| keyboardVisible, | ||
| }: BottomBarProps) => { | ||
| const [input, setInput] = useState(''); | ||
| const [numSteps, setNumSteps] = useState(10); | ||
|
|
||
| const decreaseSteps = () => setNumSteps((prev) => Math.max(5, prev - 5)); | ||
| const increaseSteps = () => setNumSteps((prev) => Math.min(50, prev + 5)); | ||
|
|
||
| if (!showTextInput) { | ||
| if (isGenerating) { | ||
| return ( | ||
| <TouchableOpacity | ||
| style={styles.button} | ||
| onPress={stopModel} | ||
| disabled={!isReady} | ||
| > | ||
| <Text style={styles.buttonText}>Stop model</Text> | ||
| </TouchableOpacity> | ||
| ); | ||
| } else { | ||
| return ( | ||
| <TouchableOpacity | ||
| style={styles.button} | ||
| onPress={() => setShowTextInput(true)} | ||
| disabled={!isReady} | ||
| > | ||
| <Text style={styles.buttonText}>Run model</Text> | ||
| </TouchableOpacity> | ||
| ); | ||
| } | ||
| } | ||
|
|
||
| return ( | ||
| <KeyboardAvoidingView | ||
| style={styles.container} | ||
| collapsable={false} | ||
| behavior={Platform.OS === 'ios' ? 'padding' : undefined} | ||
| keyboardVerticalOffset={Platform.OS === 'ios' ? 120 : 40} | ||
| > | ||
| <View style={styles.inputContainer}> | ||
| <TextInput | ||
| style={styles.input} | ||
| placeholder="Enter prompt..." | ||
| value={input} | ||
| onChangeText={setInput} | ||
| /> | ||
| <TouchableOpacity | ||
| style={[styles.button, styles.iconButton]} | ||
| onPress={() => { | ||
| setShowTextInput(false); | ||
| setInput(''); | ||
| runModel(input, numSteps); | ||
| }} | ||
| disabled={!isReady || isGenerating} | ||
| > | ||
| <Ionicons name="send" size={20} color="#fff" /> | ||
| </TouchableOpacity> | ||
| </View> | ||
|
|
||
| <View style={styles.stepsContainer}> | ||
| <Text style={[styles.text, keyboardVisible && styles.textWhite]}> | ||
| Steps: {numSteps} | ||
| </Text> | ||
| <View style={styles.stepsButtons}> | ||
| <TouchableOpacity | ||
| style={[styles.button, styles.iconButton]} | ||
| onPress={decreaseSteps} | ||
| > | ||
| <Text style={styles.buttonText}>-</Text> | ||
| </TouchableOpacity> | ||
| <TouchableOpacity | ||
| style={[styles.button, styles.iconButton]} | ||
| onPress={increaseSteps} | ||
| > | ||
| <Text style={styles.buttonText}>+</Text> | ||
| </TouchableOpacity> | ||
| </View> | ||
| </View> | ||
| </KeyboardAvoidingView> | ||
| ); | ||
| }; | ||
|
|
||
| const styles = StyleSheet.create({ | ||
| container: { | ||
| alignItems: 'center', | ||
| }, | ||
| inputContainer: { | ||
| flexDirection: 'row', | ||
| alignItems: 'center', | ||
| justifyContent: 'center', | ||
| }, | ||
| input: { | ||
| flex: 1, | ||
| borderRadius: 6, | ||
| padding: 8, | ||
| marginRight: 8, | ||
| backgroundColor: '#fff', | ||
| color: '#000', | ||
| }, | ||
| stepsContainer: { | ||
| width: '100%', | ||
| flexDirection: 'row', | ||
| alignItems: 'center', | ||
| justifyContent: 'space-between', | ||
| marginTop: 10, | ||
| }, | ||
| stepsButtons: { | ||
| flexDirection: 'row', | ||
| }, | ||
| button: { | ||
| width: '100%', | ||
| height: 40, | ||
| justifyContent: 'center', | ||
| alignItems: 'center', | ||
| backgroundColor: ColorPalette.primary, | ||
| borderRadius: 8, | ||
| }, | ||
| buttonText: { | ||
| color: '#fff', | ||
| fontSize: 16, | ||
| textAlign: 'center', | ||
| }, | ||
| iconButton: { | ||
| marginHorizontal: 5, | ||
| width: 40, | ||
| }, | ||
| text: { | ||
| flex: 1, | ||
| fontSize: 16, | ||
| color: '#000', | ||
| }, | ||
| textWhite: { | ||
| color: '#fff', | ||
| }, | ||
| }); | ||
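The step controls in `BottomBarWithTextInput` keep the step count between 5 and 50 and move it in increments of 5. A minimal sketch of that logic as pure functions, which makes the bounds easy to check in isolation. The constant names `MIN_STEPS`, `MAX_STEPS`, and `STEP_DELTA` are assumptions introduced here and do not appear in the PR.

```typescript
// Assumed names for the literals used in the component (5, 50, and 5).
const MIN_STEPS = 5;
const MAX_STEPS = 50;
const STEP_DELTA = 5;

// Lower the step count by one increment, never going below the minimum.
function decreaseSteps(prev: number): number {
  return Math.max(MIN_STEPS, prev - STEP_DELTA);
}

// Raise the step count by one increment, never going above the maximum.
function increaseSteps(prev: number): number {
  return Math.min(MAX_STEPS, prev + STEP_DELTA);
}

console.log(increaseSteps(10)); // 15
console.log(decreaseSteps(5)); // 5 (already at the minimum)
console.log(increaseSteps(50)); // 50 (already at the maximum)
```

In the component these would be used as updater functions, for example `setNumSteps(decreaseSteps)`, which matches the `(prev) => ...` form in the diff.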
FYI, this can be pulled inside `useCallback`.
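For context on the suggestion: `useCallback` returns the same function reference across renders as long as its dependency list is unchanged. The sketch below is a plain TypeScript stand-in for that behaviour, not React's implementation. `createCallbackCache` and `remember` are hypothetical names, and the `Object.is` comparison follows how React documents dependency comparison.

```typescript
type Deps = readonly unknown[];

// Hypothetical helper: holds one cached callback, like a single useCallback slot.
function createCallbackCache<T extends (...args: never[]) => unknown>() {
  let cached: { fn: T; deps: Deps } | undefined;

  return (fn: T, deps: Deps): T => {
    const prev = cached;
    const unchanged =
      prev !== undefined &&
      prev.deps.length === deps.length &&
      prev.deps.every((dep, i) => Object.is(dep, deps[i]));
    // Reuse the earlier function when every dependency is identical.
    if (prev !== undefined && unchanged) return prev.fn;
    cached = { fn, deps };
    return fn;
  };
}

const remember = createCallbackCache<(prompt: string) => string>();

// Each call models one render passing a freshly created function.
const first = remember((prompt) => `run:${prompt}`, [10]);
const second = remember((prompt) => `run:${prompt}`, [10]); // same deps
const third = remember((prompt) => `run:${prompt}`, [20]); // deps changed

const sameWhenDepsEqual = first === second;
const newWhenDepsChange = second !== third;

console.log(sameWhenDepsEqual, newWhenDepsChange);
```

In the screen component this would correspond to wrapping a handler in `useCallback` with the values it reads as dependencies. The review comment does not say which function it refers to, so which handler to wrap is left open here.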