@@ -32,7 +32,8 @@ PoseEstimation::PoseEstimation(const std::string &modelSource,
3232
3333PoseDetections PoseEstimation::postprocess (const std::vector<EValue> &tensors,
3434 cv::Size originalSize,
35- double detectionThreshold) {
35+ double detectionThreshold,
36+ double keypointThreshold) {
3637 // Output tensors (batch dim squeezed):
3738 // 0: boxes (Q, 4) - xyxy bbox in model input pixel space
3839 // 1: scores (Q,) - person confidence [0, 1]
@@ -75,6 +76,11 @@ PoseDetections PoseEstimation::postprocess(const std::vector<EValue> &tensors,
7576 const float *detectionKps = kpData + i * numKeypoints * 3 ;
7677
7778 for (size_t k = 0 ; k < numKeypoints; ++k) {
79+ float visibility = detectionKps[k * 3 + 2 ];
80+ if (visibility < keypointThreshold) {
81+ keypoints.emplace_back (-1 , -1 );
82+ continue ;
83+ }
7884 float x = detectionKps[k * 3 ];
7985 float y = detectionKps[k * 3 + 1 ];
8086
@@ -92,7 +98,7 @@ PoseDetections PoseEstimation::postprocess(const std::vector<EValue> &tensors,
9298
9399PoseDetections PoseEstimation::runInference (cv::Mat image,
94100 double detectionThreshold,
95- double iouThreshold ,
101+ double keypointThreshold ,
96102 const std::string &methodName) {
97103
98104 log (LOG_LEVEL::Debug, " Running inference with model name: " + methodName);
@@ -101,9 +107,9 @@ PoseDetections PoseEstimation::runInference(cv::Mat image,
101107 throw RnExecutorchError (RnExecutorchErrorCode::InvalidUserInput,
102108 " detectionThreshold must be in range [0, 1]" );
103109 }
104- if (iouThreshold < 0.0 || iouThreshold > 1.0 ) {
110+ if (keypointThreshold < 0.0 || keypointThreshold > 1.0 ) {
105111 throw RnExecutorchError (RnExecutorchErrorCode::InvalidUserInput,
106- " iouThreshold must be in range [0, 1]" );
112+ " keypointThreshold must be in range [0, 1]" );
107113 }
108114
109115 std::scoped_lock lock (inference_mutex_);
@@ -132,30 +138,31 @@ PoseDetections PoseEstimation::runInference(cv::Mat image,
132138 " Ensure the model input is correct." );
133139 }
134140
135- return postprocess (executeResult.get (), originalSize, detectionThreshold);
141+ return postprocess (executeResult.get (), originalSize, detectionThreshold,
142+ keypointThreshold);
136143}
137144
138145PoseDetections PoseEstimation::generateFromString (std::string imageSource,
139146 double detectionThreshold,
140- double iouThreshold ,
147+ double keypointThreshold ,
141148 std::string methodName) {
142149 cv::Mat imageBGR = image_processing::readImage (imageSource);
143150 cv::Mat imageRGB;
144151 cv::cvtColor (imageBGR, imageRGB, cv::COLOR_BGR2RGB);
145- return runInference (std::move (imageRGB), detectionThreshold, iouThreshold,
146- methodName);
152+ return runInference (std::move (imageRGB), detectionThreshold,
153+ keypointThreshold, methodName);
147154}
148155
149156PoseDetections PoseEstimation::generateFromFrame (jsi::Runtime &runtime,
150157 const jsi::Value &frameData,
151158 double detectionThreshold,
152- double iouThreshold ,
159+ double keypointThreshold ,
153160 std::string methodName) {
154161 auto orient = ::rnexecutorch::utils::readFrameOrientation (runtime, frameData);
155162 cv::Mat frame = extractFromFrame (runtime, frameData);
156163 cv::Mat rotated = ::rnexecutorch::utils::rotateFrameForModel (frame, orient);
157164 auto detections =
158- runInference (rotated, detectionThreshold, iouThreshold , methodName);
165+ runInference (rotated, detectionThreshold, keypointThreshold , methodName);
159166 for (auto &person : detections) {
160167 ::rnexecutorch::utils::inverseRotatePoints (person, orient, rotated.size());
161168 }
@@ -164,10 +171,10 @@ PoseDetections PoseEstimation::generateFromFrame(jsi::Runtime &runtime,
164171
165172PoseDetections PoseEstimation::generateFromPixels (JSTensorViewIn pixelData,
166173 double detectionThreshold,
167- double iouThreshold ,
174+ double keypointThreshold ,
168175 std::string methodName) {
169176 cv::Mat image = extractFromPixels (pixelData);
170- return runInference (image, detectionThreshold, iouThreshold , methodName);
177+ return runInference (image, detectionThreshold, keypointThreshold , methodName);
171178}
172179
173180} // namespace rnexecutorch::models::pose_estimation
0 commit comments