Skip to content

Commit e2531bc

Browse files
committed
chore: review suggestions
1 parent df1ef5f commit e2531bc

7 files changed

Lines changed: 110 additions & 64 deletions

File tree

apps/computer-vision/app/pose_estimation/index.tsx

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,11 @@ export default function PoseEstimationScreen() {
155155
(layout.width - imageDimensions.width * scaleX) / 2;
156156
const offsetY =
157157
(layout.height - imageDimensions.height * scaleY) / 2;
158+
const isInBounds = (kp: { x: number; y: number }) =>
159+
kp.x >= 0 &&
160+
kp.y >= 0 &&
161+
kp.x <= imageDimensions.width &&
162+
kp.y <= imageDimensions.height;
158163
return (
159164
<Svg style={StyleSheet.absoluteFill}>
160165
{results.map((personKeypoints, personIdx) => {
@@ -167,6 +172,8 @@ export default function PoseEstimationScreen() {
167172
const kp1 = personKeypoints[from];
168173
const kp2 = personKeypoints[to];
169174
if (!kp1 || !kp2) return null;
175+
if (!isInBounds(kp1) || !isInBounds(kp2))
176+
return null;
170177
return (
171178
<Line
172179
key={`person-${personIdx}-line-${lineIdx}`}
@@ -180,17 +187,17 @@ export default function PoseEstimationScreen() {
180187
);
181188
}
182189
)}
183-
{Object.entries(personKeypoints).map(
184-
([name, kp]) => (
190+
{Object.entries(personKeypoints)
191+
.filter(([, kp]) => isInBounds(kp))
192+
.map(([name, kp]) => (
185193
<Circle
186194
key={`person-${personIdx}-kp-${name}`}
187195
cx={kp.x * scaleX + offsetX}
188196
cy={kp.y * scaleY + offsetY}
189197
r="4"
190198
fill="red"
191199
/>
192-
)
193-
)}
200+
))}
194201
</React.Fragment>
195202
);
196203
})}

apps/computer-vision/components/vision_camera/tasks/PoseEstimationTask.tsx

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,16 @@ export default function PoseEstimationTask({
142142
<Svg style={StyleSheet.absoluteFill}>
143143
{detections.map((personKeypoints, personIdx) => {
144144
const color = PERSON_COLORS[personIdx % PERSON_COLORS.length];
145+
const isVisible = (kp: { x: number; y: number }) =>
146+
kp.x >= 0 && kp.y >= 0;
145147
return (
146148
<React.Fragment key={`person-${personIdx}`}>
147149
{/* Draw skeleton lines */}
148150
{COCO_SKELETON_CONNECTIONS.map(([from, to], lineIdx) => {
149151
const kp1 = personKeypoints[from];
150152
const kp2 = personKeypoints[to];
151153
if (!kp1 || !kp2) return null;
154+
if (!isVisible(kp1) || !isVisible(kp2)) return null;
152155
const x1 = kp1.x * scale + offsetX;
153156
const y1 = kp1.y * scale + offsetY;
154157
const x2 = kp2.x * scale + offsetX;
@@ -166,19 +169,21 @@ export default function PoseEstimationTask({
166169
);
167170
})}
168171
{/* Draw keypoints */}
169-
{Object.entries(personKeypoints).map(([name, kp]) => {
170-
const cx = kp.x * scale + offsetX;
171-
const cy = kp.y * scale + offsetY;
172-
return (
173-
<Circle
174-
key={`person-${personIdx}-kp-${name}`}
175-
cx={cx}
176-
cy={cy}
177-
r={5}
178-
fill="red"
179-
/>
180-
);
181-
})}
172+
{Object.entries(personKeypoints)
173+
.filter(([, kp]) => isVisible(kp))
174+
.map(([name, kp]) => {
175+
const cx = kp.x * scale + offsetX;
176+
const cy = kp.y * scale + offsetY;
177+
return (
178+
<Circle
179+
key={`person-${personIdx}-kp-${name}`}
180+
cx={cx}
181+
cy={cy}
182+
r={5}
183+
fill="red"
184+
/>
185+
);
186+
})}
182187
</React.Fragment>
183188
);
184189
})}

packages/react-native-executorch/common/rnexecutorch/models/pose_estimation/PoseEstimation.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ PoseEstimation::PoseEstimation(const std::string &modelSource,
3232

3333
PoseDetections PoseEstimation::postprocess(const std::vector<EValue> &tensors,
3434
cv::Size originalSize,
35-
double detectionThreshold) {
35+
double detectionThreshold,
36+
double keypointThreshold) {
3637
// Output tensors (batch dim squeezed):
3738
// 0: boxes (Q, 4) - xyxy bbox in model input pixel space
3839
// 1: scores (Q,) - person confidence [0, 1]
@@ -75,6 +76,11 @@ PoseDetections PoseEstimation::postprocess(const std::vector<EValue> &tensors,
7576
const float *detectionKps = kpData + i * numKeypoints * 3;
7677

7778
for (size_t k = 0; k < numKeypoints; ++k) {
79+
float visibility = detectionKps[k * 3 + 2];
80+
if (visibility < keypointThreshold) {
81+
keypoints.emplace_back(-1, -1);
82+
continue;
83+
}
7884
float x = detectionKps[k * 3];
7985
float y = detectionKps[k * 3 + 1];
8086

@@ -92,7 +98,7 @@ PoseDetections PoseEstimation::postprocess(const std::vector<EValue> &tensors,
9298

9399
PoseDetections PoseEstimation::runInference(cv::Mat image,
94100
double detectionThreshold,
95-
double iouThreshold,
101+
double keypointThreshold,
96102
const std::string &methodName) {
97103

98104
log(LOG_LEVEL::Debug, "Running inference with model name: " + methodName);
@@ -101,9 +107,9 @@ PoseDetections PoseEstimation::runInference(cv::Mat image,
101107
throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
102108
"detectionThreshold must be in range [0, 1]");
103109
}
104-
if (iouThreshold < 0.0 || iouThreshold > 1.0) {
110+
if (keypointThreshold < 0.0 || keypointThreshold > 1.0) {
105111
throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
106-
"iouThreshold must be in range [0, 1]");
112+
"keypointThreshold must be in range [0, 1]");
107113
}
108114

109115
std::scoped_lock lock(inference_mutex_);
@@ -132,30 +138,31 @@ PoseDetections PoseEstimation::runInference(cv::Mat image,
132138
"Ensure the model input is correct.");
133139
}
134140

135-
return postprocess(executeResult.get(), originalSize, detectionThreshold);
141+
return postprocess(executeResult.get(), originalSize, detectionThreshold,
142+
keypointThreshold);
136143
}
137144

138145
PoseDetections PoseEstimation::generateFromString(std::string imageSource,
139146
double detectionThreshold,
140-
double iouThreshold,
147+
double keypointThreshold,
141148
std::string methodName) {
142149
cv::Mat imageBGR = image_processing::readImage(imageSource);
143150
cv::Mat imageRGB;
144151
cv::cvtColor(imageBGR, imageRGB, cv::COLOR_BGR2RGB);
145-
return runInference(std::move(imageRGB), detectionThreshold, iouThreshold,
146-
methodName);
152+
return runInference(std::move(imageRGB), detectionThreshold,
153+
keypointThreshold, methodName);
147154
}
148155

149156
PoseDetections PoseEstimation::generateFromFrame(jsi::Runtime &runtime,
150157
const jsi::Value &frameData,
151158
double detectionThreshold,
152-
double iouThreshold,
159+
double keypointThreshold,
153160
std::string methodName) {
154161
auto orient = ::rnexecutorch::utils::readFrameOrientation(runtime, frameData);
155162
cv::Mat frame = extractFromFrame(runtime, frameData);
156163
cv::Mat rotated = ::rnexecutorch::utils::rotateFrameForModel(frame, orient);
157164
auto detections =
158-
runInference(rotated, detectionThreshold, iouThreshold, methodName);
165+
runInference(rotated, detectionThreshold, keypointThreshold, methodName);
159166
for (auto &person : detections) {
160167
::rnexecutorch::utils::inverseRotatePoints(person, orient, rotated.size());
161168
}
@@ -164,10 +171,10 @@ PoseDetections PoseEstimation::generateFromFrame(jsi::Runtime &runtime,
164171

165172
PoseDetections PoseEstimation::generateFromPixels(JSTensorViewIn pixelData,
166173
double detectionThreshold,
167-
double iouThreshold,
174+
double keypointThreshold,
168175
std::string methodName) {
169176
cv::Mat image = extractFromPixels(pixelData);
170-
return runInference(image, detectionThreshold, iouThreshold, methodName);
177+
return runInference(image, detectionThreshold, keypointThreshold, methodName);
171178
}
172179

173180
} // namespace rnexecutorch::models::pose_estimation

packages/react-native-executorch/common/rnexecutorch/models/pose_estimation/PoseEstimation.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,27 +17,28 @@ class PoseEstimation : public VisionModel {
1717

1818
[[nodiscard("Registered non-void function")]] PoseDetections
1919
generateFromString(std::string imageSource, double detectionThreshold,
20-
double iouThreshold, std::string methodName);
20+
double keypointThreshold, std::string methodName);
2121
[[nodiscard("Registered non-void function")]] PoseDetections
2222
generateFromFrame(jsi::Runtime &runtime, const jsi::Value &frameData,
23-
double detectionThreshold, double iouThreshold,
23+
double detectionThreshold, double keypointThreshold,
2424
std::string methodName);
2525
[[nodiscard("Registered non-void function")]] PoseDetections
2626
generateFromPixels(JSTensorViewIn pixelData, double detectionThreshold,
27-
double iouThreshold, std::string methodName);
27+
double keypointThreshold, std::string methodName);
2828

2929
private:
3030
std::optional<cv::Scalar> normMean_;
3131
std::optional<cv::Scalar> normStd_;
3232

3333
[[nodiscard("Registered non-void function")]]
3434
PoseDetections runInference(cv::Mat image, double detectionThreshold,
35-
double iouThreshold,
35+
double keypointThreshold,
3636
const std::string &modelName);
3737

3838
[[nodiscard("Registered non-void function")]]
3939
PoseDetections postprocess(const std::vector<EValue> &evl,
40-
cv::Size originalSize, double detectionThreshold);
40+
cv::Size originalSize, double detectionThreshold,
41+
double keypointThreshold);
4142
};
4243

4344
} // namespace models::pose_estimation

packages/react-native-executorch/common/rnexecutorch/tests/integration/PoseEstimationTest.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,14 @@ TEST(PoseEstimationGenerateTests, DetectionThresholdAboveOneThrows) {
8282
RnExecutorchError);
8383
}
8484

85-
TEST(PoseEstimationGenerateTests, NegativeIouThresholdThrows) {
85+
TEST(PoseEstimationGenerateTests, NegativeKeypointThresholdThrows) {
8686
PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr);
8787
EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, -0.1,
8888
kMethodName),
8989
RnExecutorchError);
9090
}
9191

92-
TEST(PoseEstimationGenerateTests, IouThresholdAboveOneThrows) {
92+
TEST(PoseEstimationGenerateTests, KeypointThresholdAboveOneThrows) {
9393
PoseEstimation model(kValidPoseModelPath, {}, {}, nullptr);
9494
EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 1.1,
9595
kMethodName),

packages/react-native-executorch/src/modules/computer_vision/PoseEstimationModule.ts

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ const YOLO_POSE_CONFIG = {
2222
availableInputSizes: [384, 512, 640] as const,
2323
defaultInputSize: 384,
2424
defaultDetectionThreshold: 0.5,
25-
defaultIouThreshold: 0.5,
25+
defaultKeypointThreshold: 0.5,
2626
} satisfies PoseEstimationConfig<typeof CocoKeypoint>;
2727

2828
const ModelConfigs = {
@@ -48,6 +48,26 @@ type ModelNameOf<C extends PoseEstimationModelSources> = C['modelName'];
4848
type ResolveKeypoints<T extends PoseEstimationModelName | KeypointEnum> =
4949
ResolveConfigOrType<T, ModelConfigsType, 'keypointMap'>;
5050

51+
function mapPersonKeypoints<K extends KeypointEnum>(
52+
raw: Keypoint[][],
53+
entries: [string, number][],
54+
maxIndex: number
55+
): PersonKeypoints<K>[] {
56+
'worklet';
57+
if (raw.length > 0 && raw[0]!.length <= maxIndex) {
58+
throw new Error(
59+
`Keypoint map references index ${maxIndex} but model returned ${raw[0]!.length} keypoints per person — keypointMap is incompatible with this model.`
60+
);
61+
}
62+
const out: PersonKeypoints<K>[] = [];
63+
for (const person of raw) {
64+
const named: Record<string, Keypoint> = {};
65+
for (const [name, idx] of entries) named[name] = person[idx]!;
66+
out.push(named as PersonKeypoints<K>);
67+
}
68+
return out;
69+
}
70+
5171
/**
5272
* Pose estimation module for detecting human body keypoints.
5373
* @typeParam T - Either a built-in model name (e.g. `'yolo26n-pose'`)
@@ -59,6 +79,7 @@ export class PoseEstimationModule<
5979
> extends VisionModule<PoseDetections<ResolveKeypoints<T>>> {
6080
private readonly keypointMap: ResolveKeypoints<T>;
6181
private readonly modelConfig: PoseEstimationConfig<KeypointEnum>;
82+
private readonly maxKeypointIndex: number;
6283

6384
private constructor(
6485
keypointMap: ResolveKeypoints<T>,
@@ -69,6 +90,7 @@ export class PoseEstimationModule<
6990
this.keypointMap = keypointMap;
7091
this.modelConfig = modelConfig;
7192
this.nativeModule = nativeModule;
93+
this.maxKeypointIndex = Math.max(...Object.values(keypointMap));
7294
}
7395

7496
/**
@@ -169,14 +191,12 @@ export class PoseEstimationModule<
169191
const nativeGenerateFromFrame = this.nativeModule.generateFromFrame;
170192
const defaultDetectionThreshold =
171193
this.modelConfig.defaultDetectionThreshold ?? 0.5;
172-
const defaultIouThreshold = this.modelConfig.defaultIouThreshold ?? 0.5;
194+
const defaultKeypointThreshold =
195+
this.modelConfig.defaultKeypointThreshold ?? 0.5;
173196
const defaultInputSize = this.modelConfig.defaultInputSize;
174197
const availableInputSizes = this.modelConfig.availableInputSizes;
175-
const keypointEntries = Object.entries(this.keypointMap) as [
176-
string,
177-
number,
178-
][];
179-
198+
const keypointEntries = Object.entries(this.keypointMap);
199+
const maxKeypointIndex = this.maxKeypointIndex;
180200
return (
181201
frame: Frame,
182202
isFrontCamera: boolean,
@@ -186,7 +206,8 @@ export class PoseEstimationModule<
186206

187207
const detectionThreshold =
188208
options?.detectionThreshold ?? defaultDetectionThreshold;
189-
const iouThreshold = options?.iouThreshold ?? defaultIouThreshold;
209+
const keypointThreshold =
210+
options?.keypointThreshold ?? defaultKeypointThreshold;
190211
const inputSize = options?.inputSize ?? defaultInputSize;
191212

192213
// Validate inputSize
@@ -214,16 +235,14 @@ export class PoseEstimationModule<
214235
const raw: Keypoint[][] = nativeGenerateFromFrame(
215236
frameData,
216237
detectionThreshold,
217-
iouThreshold,
238+
keypointThreshold,
218239
methodName
219240
);
220-
const out: PersonKeypoints<ResolveKeypoints<T>>[] = [];
221-
for (const person of raw) {
222-
const named: Record<string, Keypoint> = {};
223-
for (const [name, idx] of keypointEntries) named[name] = person[idx]!;
224-
out.push(named as PersonKeypoints<ResolveKeypoints<T>>);
225-
}
226-
return out;
241+
return mapPersonKeypoints<ResolveKeypoints<T>>(
242+
raw,
243+
keypointEntries,
244+
maxKeypointIndex
245+
);
227246
} finally {
228247
if (nativeBuffer?.release) {
229248
nativeBuffer.release();
@@ -253,8 +272,10 @@ export class PoseEstimationModule<
253272
options?.detectionThreshold ??
254273
this.modelConfig.defaultDetectionThreshold ??
255274
0.5;
256-
const iouThreshold =
257-
options?.iouThreshold ?? this.modelConfig.defaultIouThreshold ?? 0.5;
275+
const keypointThreshold =
276+
options?.keypointThreshold ??
277+
this.modelConfig.defaultKeypointThreshold ??
278+
0.5;
258279
const inputSize = options?.inputSize ?? this.modelConfig.defaultInputSize;
259280

260281
// Validate inputSize against availableInputSizes
@@ -277,21 +298,21 @@ export class PoseEstimationModule<
277298
? await this.nativeModule.generateFromString(
278299
input,
279300
detectionThreshold,
280-
iouThreshold,
301+
keypointThreshold,
281302
methodName
282303
)
283304
: await this.nativeModule.generateFromPixels(
284305
input,
285306
detectionThreshold,
286-
iouThreshold,
307+
keypointThreshold,
287308
methodName
288309
);
289310

290-
const entries = Object.entries(this.keypointMap) as [string, number][];
291-
return raw.map((person) => {
292-
const named: Record<string, Keypoint> = {};
293-
for (const [name, idx] of entries) named[name] = person[idx]!;
294-
return named as PersonKeypoints<ResolveKeypoints<T>>;
295-
});
311+
const entries = Object.entries(this.keypointMap);
312+
return mapPersonKeypoints<ResolveKeypoints<T>>(
313+
raw,
314+
entries,
315+
this.maxKeypointIndex
316+
);
296317
}
297318
}

0 commit comments

Comments
 (0)