Skip to content

Commit d97781d

Browse files
authored
feat: pose estimation (#1100)
## Description Adds a new pose estimation module to react-native-executorch, mirroring the object detection API surface. - Native (C++): pose estimation runtime built on VisionModel, taking only modelSource + optional normMean / normStd. numKeypoints is derived from the model's output tensor shape — no keypointNames are plumbed through the bridge. - TS API: PoseEstimationModule (fromModelName / fromCustomModel) and usePoseEstimation hook. The return type is statically tied to keypointMap. The runtime maps the native positional array into the named record at the boundary, in both forward and the runOnFrame worklet. - Built-in models: yolo26n-pose (more model sizes will be added soon). - Custom models: supported via fromCustomModel with a user-supplied keypointMap; the output type follows from that map. The required model contract (input shape, optional normalization, three output tensors boxes / scores / keypoints) is documented. - Docs: new usePoseEstimation and PoseEstimationModule pages following the object detection structure, plus a VisionCamera integration entry. ### Introduces a breaking change? - [ ] Yes - [x] No ### Type of change - [ ] Bug fix (change which fixes an issue) - [x] New feature (change which adds functionality) - [ ] Documentation update (improves or adds clarity to existing documentation) - [ ] Other (chores, tests, code style improvements etc.) ### Tested on - [x] iOS - [x] Android ### Testing instructions - [x] Run native sanity tests and verify that they pass - [ ] Run the computer vision example app, both for vision camera and static images. 
### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [ ] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent 33b2fab commit d97781d

30 files changed

Lines changed: 1978 additions & 44 deletions

File tree

.cspell-wordlist.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,8 @@ BIOES
193193
viterbi
194194
argmaxes
195195
unpadded
196+
keypoint
197+
keypoints
198+
Keypoint
199+
Keypoints
200+
letterboxing

.eslintrc.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ const VALID_CATEGORIES = [
1010
'Models - LLM',
1111
'Models - Object Detection',
1212
'Models - Instance Segmentation',
13+
'Models - Pose Estimation',
1314
'Models - Semantic Segmentation',
1415
'Models - Speech To Text',
1516
'Models - Style Transfer',

apps/computer-vision/app/_layout.tsx

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,14 @@ export default function _layout() {
149149
headerTitleStyle: { color: ColorPalette.primary },
150150
}}
151151
/>
152+
<Drawer.Screen
153+
name="pose_estimation/index"
154+
options={{
155+
drawerLabel: 'Pose Estimation',
156+
title: 'Pose Estimation',
157+
headerTitleStyle: { color: ColorPalette.primary },
158+
}}
159+
/>
152160
<Drawer.Screen
153161
name="ocr/index"
154162
options={{

apps/computer-vision/app/index.tsx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ export default function Home() {
4141
>
4242
<Text style={styles.buttonText}>Instance Segmentation</Text>
4343
</TouchableOpacity>
44+
<TouchableOpacity
45+
style={styles.button}
46+
onPress={() => router.navigate('pose_estimation/')}
47+
>
48+
<Text style={styles.buttonText}>Pose Estimation</Text>
49+
</TouchableOpacity>
4450
<TouchableOpacity
4551
style={styles.button}
4652
onPress={() => router.navigate('ocr/')}
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
import Spinner from '../../components/Spinner';
2+
import { BottomBar } from '../../components/BottomBar';
3+
import { getImage } from '../../utils';
4+
import {
5+
usePoseEstimation,
6+
PoseDetections,
7+
RnExecutorchError,
8+
RnExecutorchErrorCode,
9+
YOLO26N_POSE,
10+
} from 'react-native-executorch';
11+
import { View, StyleSheet, Image, Text } from 'react-native';
12+
import React, { useContext, useEffect, useState } from 'react';
13+
import { GeneratingContext } from '../../context';
14+
import ScreenWrapper from '../../ScreenWrapper';
15+
import { StatsBar } from '../../components/StatsBar';
16+
import Svg, { Circle, Line } from 'react-native-svg';
17+
import ErrorBanner from '../../components/ErrorBanner';
18+
import { COCO_SKELETON_CONNECTIONS } from '../../components/utils/cocoSkeleton';
19+
20+
// Stroke colors for the skeleton lines, one per detected person. The screen
// indexes this list modulo its length, so colors repeat after six people.
21+
const PERSON_COLORS = ['lime', 'cyan', 'magenta', 'yellow', 'orange', 'pink'];
22+
23+
/**
 * Example screen: runs the YOLO26N pose model on a picked image and draws the
 * detected keypoints (red circles) and skeleton lines over the image.
 *
 * Renders a loading spinner until the model reports `isReady`; afterwards it
 * shows the image (or a placeholder logo), an error banner, inference stats,
 * and the bottom bar that triggers image picking and inference.
 */
export default function PoseEstimationScreen() {
24+
const [imageUri, setImageUri] = useState('');
25+
// Detections returned by the most recent `forward` call; one entry per person.
const [results, setResults] = useState<PoseDetections>([]);
26+
const [error, setError] = useState<string | null>(null);
27+
// Dimensions of the picked image as reported by `getImage`.
const [imageDimensions, setImageDimensions] = useState<{
28+
width: number;
29+
height: number;
30+
}>();
31+
// Wall-clock duration of the last `forward` call, in milliseconds.
const [inferenceTime, setInferenceTime] = useState<number | null>(null);
32+
// Measured size of the image wrapper view, updated from its `onLayout`.
const [layout, setLayout] = useState({ width: 0, height: 0 });
33+
34+
const model = usePoseEstimation({ model: YOLO26N_POSE });
35+
const { setGlobalGenerating } = useContext(GeneratingContext);
36+
37+
// Mirror the model's generating flag into the shared GeneratingContext.
useEffect(() => {
38+
setGlobalGenerating(model.isGenerating);
39+
}, [model.isGenerating, setGlobalGenerating]);
40+
41+
// Surface hook-level errors through the same banner as inference errors.
useEffect(() => {
42+
if (model.error) setError(String(model.error));
43+
}, [model.error]);
44+
45+
// Picks an image via `getImage` and resets the previous results and timing.
// NOTE(review): `isCamera` presumably selects camera vs. gallery — confirm in utils.
const handleCameraPress = async (isCamera: boolean) => {
46+
const image = await getImage(isCamera);
47+
const uri = image?.uri;
48+
const width = image?.width;
49+
const height = image?.height;
50+
51+
// Truthiness check: a missing uri or a zero/undefined dimension is ignored.
if (uri && width && height) {
52+
setImageUri(image.uri as string);
53+
setImageDimensions({ width, height });
54+
setResults([]);
55+
setInferenceTime(null);
56+
}
57+
};
58+
59+
// Runs inference on the selected image, records the elapsed time, and maps
// known RnExecutorchError codes to user-facing messages.
const runForward = async () => {
60+
if (imageUri) {
61+
try {
62+
const start = Date.now();
63+
const output = await model.forward(imageUri, { inputSize: 384 });
64+
setInferenceTime(Date.now() - start);
65+
setResults(output);
66+
} catch (e) {
67+
if (e instanceof RnExecutorchError) {
68+
switch (e.code) {
69+
case RnExecutorchErrorCode.FileReadFailed:
70+
setError('Could not read the selected image.');
71+
break;
72+
case RnExecutorchErrorCode.ModelGenerating:
73+
setError('Model is busy — wait for the current run to finish.');
74+
break;
75+
case RnExecutorchErrorCode.InvalidUserInput:
76+
case RnExecutorchErrorCode.InvalidArgument:
77+
setError(`Invalid input: ${e.message}`);
78+
break;
79+
default:
80+
setError(e.message);
81+
}
82+
} else {
83+
setError(e instanceof Error ? e.message : String(e));
84+
}
85+
}
86+
}
87+
};
88+
89+
// Until the model is ready, show only the spinner with download progress.
if (!model.isReady) {
90+
return (
91+
<Spinner
92+
visible={!model.isReady}
93+
textContent={`Loading the model ${(model.downloadProgress * 100).toFixed(0)} %`}
94+
/>
95+
);
96+
}
97+
98+
return (
99+
<ScreenWrapper>
100+
<ErrorBanner message={error} onDismiss={() => setError(null)} />
101+
<View style={styles.imageContainer}>
102+
<View style={styles.image}>
103+
{imageUri && imageDimensions?.width && imageDimensions?.height ? (
104+
<View
105+
style={styles.imageWrapper}
106+
onLayout={(e) =>
107+
setLayout({
108+
width: e.nativeEvent.layout.width,
109+
height: e.nativeEvent.layout.height,
110+
})
111+
}
112+
>
113+
<Image
114+
source={{ uri: imageUri }}
115+
style={styles.fullSizeImage}
116+
resizeMode="contain"
117+
/>
118+
{results.length > 0 &&
119+
layout.width > 0 &&
120+
layout.height > 0 &&
121+
(() => {
122+
// Account for resizeMode="contain" letterboxing: the image's
123+
// displayed area is smaller than the container in one axis.
124+
const imageRatio =
125+
imageDimensions.width / imageDimensions.height;
126+
const layoutRatio = layout.width / layout.height;
127+
// In each branch below scaleX and scaleY work out to the same value
// (layout.width / image width, or layout.height / image height), i.e.
// a uniform scale that fits the image inside the wrapper.
let scaleX: number, scaleY: number;
128+
if (imageRatio > layoutRatio) {
129+
scaleX = layout.width / imageDimensions.width;
130+
scaleY = layout.width / imageRatio / imageDimensions.height;
131+
} else {
132+
scaleY = layout.height / imageDimensions.height;
133+
scaleX =
134+
(layout.height * imageRatio) / imageDimensions.width;
135+
}
136+
// Offsets center the scaled image inside the wrapper.
const offsetX =
137+
(layout.width - imageDimensions.width * scaleX) / 2;
138+
const offsetY =
139+
(layout.height - imageDimensions.height * scaleY) / 2;
140+
// Keypoints are compared against the source image dimensions, so
// they are treated here as image-pixel coordinates; anything outside
// the image rectangle is not drawn.
const isInBounds = (kp: { x: number; y: number }) =>
141+
kp.x >= 0 &&
142+
kp.y >= 0 &&
143+
kp.x <= imageDimensions.width &&
144+
kp.y <= imageDimensions.height;
145+
return (
146+
<Svg style={StyleSheet.absoluteFill}>
147+
{results.map((personKeypoints, personIdx) => {
148+
// Cycle through the palette so each person gets a line color.
const color =
149+
PERSON_COLORS[personIdx % PERSON_COLORS.length];
150+
return (
151+
<React.Fragment key={`person-${personIdx}`}>
152+
{COCO_SKELETON_CONNECTIONS.map(
153+
([from, to], lineIdx) => {
154+
// Skip a connection when either endpoint is missing or
// lies outside the image.
const kp1 = personKeypoints[from];
155+
const kp2 = personKeypoints[to];
156+
if (!kp1 || !kp2) return null;
157+
if (!isInBounds(kp1) || !isInBounds(kp2))
158+
return null;
159+
return (
160+
<Line
161+
key={`person-${personIdx}-line-${lineIdx}`}
162+
x1={kp1.x * scaleX + offsetX}
163+
y1={kp1.y * scaleY + offsetY}
164+
x2={kp2.x * scaleX + offsetX}
165+
y2={kp2.y * scaleY + offsetY}
166+
stroke={color}
167+
strokeWidth="2"
168+
/>
169+
);
170+
}
171+
)}
172+
{Object.entries(personKeypoints)
173+
.filter(([, kp]) => isInBounds(kp))
174+
.map(([name, kp]) => (
175+
<Circle
176+
key={`person-${personIdx}-kp-${name}`}
177+
cx={kp.x * scaleX + offsetX}
178+
cy={kp.y * scaleY + offsetY}
179+
r="4"
180+
fill="red"
181+
/>
182+
))}
183+
</React.Fragment>
184+
);
185+
})}
186+
</Svg>
187+
);
188+
})()}
189+
</View>
190+
) : (
191+
<Image
192+
style={styles.fullSizeImage}
193+
resizeMode="contain"
194+
source={require('../../assets/icons/executorch_logo.png')}
195+
/>
196+
)}
197+
</View>
198+
{!imageUri && (
199+
<View style={styles.infoContainer}>
200+
<Text style={styles.infoTitle}>Pose Estimation</Text>
201+
<Text style={styles.infoText}>
202+
This model detects human body keypoints (17 COCO keypoints) and
203+
draws a skeleton overlay. Pick an image from your gallery or take
204+
one with your camera to get started.
205+
</Text>
206+
</View>
207+
)}
208+
</View>
209+
<StatsBar
210+
inferenceTime={inferenceTime}
211+
detectionCount={results.length > 0 ? results.length : null}
212+
/>
213+
<BottomBar
214+
handleCameraPress={handleCameraPress}
215+
runForward={runForward}
216+
hasImage={!!imageUri}
217+
isGenerating={model.isGenerating}
218+
/>
219+
</ScreenWrapper>
220+
);
221+
}
222+
223+
// Static styles for the pose estimation screen: the image area, the wrapper
// whose layout is measured for the overlay, and the intro text shown before
// an image is picked.
const styles = StyleSheet.create({
224+
imageContainer: {
225+
flex: 6,
226+
width: '100%',
227+
padding: 16,
228+
},
229+
image: {
230+
flex: 2,
231+
borderRadius: 8,
232+
width: '100%',
233+
},
234+
// Wrapper around the picked image; its onLayout size drives the overlay math.
imageWrapper: {
235+
flex: 1,
236+
width: '100%',
237+
height: '100%',
238+
},
239+
fullSizeImage: {
240+
width: '100%',
241+
height: '100%',
242+
},
243+
infoContainer: {
244+
alignItems: 'center',
245+
padding: 16,
246+
gap: 8,
247+
},
248+
infoTitle: {
249+
fontSize: 18,
250+
fontWeight: '600',
251+
color: 'navy',
252+
},
253+
infoText: {
254+
fontSize: 14,
255+
color: '#555',
256+
textAlign: 'center',
257+
lineHeight: 20,
258+
},
259+
});

apps/computer-vision/app/vision_camera/index.tsx

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import SegmentationTask from '../../components/vision_camera/tasks/SegmentationT
2828
import InstanceSegmentationTask from '../../components/vision_camera/tasks/InstanceSegmentationTask';
2929
import OCRTask from '../../components/vision_camera/tasks/OCRTask';
3030
import StyleTransferTask from '../../components/vision_camera/tasks/StyleTransferTask';
31+
import PoseEstimationTask from '../../components/vision_camera/tasks/PoseEstimationTask';
3132
// 1. Import ErrorBanner
3233
import ErrorBanner from '../../components/ErrorBanner';
3334

@@ -36,6 +37,7 @@ type TaskId =
3637
| 'objectDetection'
3738
| 'segmentation'
3839
| 'instanceSegmentation'
40+
| 'poseEstimation'
3941
| 'ocr'
4042
| 'styleTransfer';
4143
type ModelId =
@@ -52,6 +54,7 @@ type ModelId =
5254
| 'segmentationSelfie'
5355
| 'instanceSegmentationYolo26n'
5456
| 'instanceSegmentationRfdetr'
57+
| 'poseEstimationYolo26n'
5558
| 'ocr'
5659
| 'styleTransferCandy'
5760
| 'styleTransferMosaic';
@@ -86,6 +89,11 @@ const TASKS: Task[] = [
8689
{ id: 'instanceSegmentationRfdetr', label: 'RF-DETR Nano Seg' },
8790
],
8891
},
92+
{
93+
id: 'poseEstimation',
94+
label: 'Pose',
95+
variants: [{ id: 'poseEstimationYolo26n', label: 'YOLO26N Pose' }],
96+
},
8997
{
9098
id: 'objectDetection',
9199
label: 'Detect',
@@ -223,6 +231,12 @@ export default function VisionCameraScreen() {
223231
outputs={frameOutput ? [frameOutput] : []}
224232
isActive={isFocused}
225233
orientationSource="device"
234+
onError={(e) => {
235+
console.warn('[Camera] onError', e);
236+
setError(e.message);
237+
}}
238+
onStarted={() => console.log('[Camera] session started')}
239+
onPreviewStarted={() => console.log('[Camera] preview got first frame')}
226240
/>
227241

228242
<View
@@ -273,6 +287,12 @@ export default function VisionCameraScreen() {
273287
}
274288
/>
275289
)}
290+
{activeTask === 'poseEstimation' && (
291+
<PoseEstimationTask
292+
{...taskProps}
293+
activeModel={activeModel as 'poseEstimationYolo26n'}
294+
/>
295+
)}
276296
{activeTask === 'ocr' && <OCRTask {...taskProps} />}
277297
{activeTask === 'styleTransfer' && (
278298
<StyleTransferTask
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
/**
 * Pairs of keypoint names that are joined by a line when drawing a skeleton.
 * Each entry is `[from, to]`; consumers look both names up on a detected
 * person's keypoint record. `as const` keeps the names as literal types.
 */
export const COCO_SKELETON_CONNECTIONS = [
2+
['NOSE', 'LEFT_EYE'],
3+
['NOSE', 'RIGHT_EYE'],
4+
['LEFT_EYE', 'LEFT_EAR'],
5+
['RIGHT_EYE', 'RIGHT_EAR'],
6+
['LEFT_SHOULDER', 'RIGHT_SHOULDER'],
7+
['LEFT_SHOULDER', 'LEFT_ELBOW'],
8+
['LEFT_ELBOW', 'LEFT_WRIST'],
9+
['RIGHT_SHOULDER', 'RIGHT_ELBOW'],
10+
['RIGHT_ELBOW', 'RIGHT_WRIST'],
11+
['LEFT_SHOULDER', 'LEFT_HIP'],
12+
['RIGHT_SHOULDER', 'RIGHT_HIP'],
13+
['LEFT_HIP', 'RIGHT_HIP'],
14+
['LEFT_HIP', 'LEFT_KNEE'],
15+
['LEFT_KNEE', 'LEFT_ANKLE'],
16+
['RIGHT_HIP', 'RIGHT_KNEE'],
17+
['RIGHT_KNEE', 'RIGHT_ANKLE'],
18+
] as const;

0 commit comments

Comments
 (0)