Skip to content

Commit aea9c26

Browse files
feat: unify frame extraction and preprocessing
1 parent 4c9f64f commit aea9c26

7 files changed

Lines changed: 1792 additions & 1782 deletions

File tree

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#include "VisionModel.h"
2+
#include <rnexecutorch/utils/FrameProcessor.h>
3+
4+
namespace rnexecutorch {
5+
namespace models {
6+
7+
using namespace facebook;
8+
9+
/**
 * Turns a VisionCamera frame value into a model-ready matrix.
 *
 * The JSI value is interpreted as an object, handed to
 * utils::FrameProcessor::extractFrame() to obtain a cv::Mat, and the result
 * is forwarded to the subclass-provided preprocessFrame() hook.
 *
 * No locking is performed here; any exception raised while converting the
 * value or extracting the frame propagates to the caller.
 */
cv::Mat VisionModel::extractAndPreprocess(jsi::Runtime &runtime,
                                          const jsi::Value &frameData) const {
  // Step 1: pull the raw frame out of the JS object.
  const cv::Mat rawFrame = utils::FrameProcessor::extractFrame(
      runtime, frameData.asObject(runtime));

  // Step 2: delegate to the model-specific preprocessing.
  return preprocessFrame(rawFrame);
}
18+
19+
} // namespace models
20+
} // namespace rnexecutorch
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
#pragma once
2+
3+
#include <jsi/jsi.h>
4+
#include <mutex>
5+
#include <opencv2/opencv.hpp>
6+
#include <rnexecutorch/metaprogramming/ConstructorHelpers.h>
7+
#include <rnexecutorch/models/BaseModel.h>
8+
9+
namespace rnexecutorch {
10+
namespace models {
11+
12+
/**
13+
* @brief Base class for computer vision models that support real-time camera
14+
* input
15+
*
16+
* VisionModel extends BaseModel with thread-safe inference and automatic frame
17+
* extraction from VisionCamera. This class is designed for models that need to
18+
* process camera frames in real-time (e.g., at 30fps).
19+
*
20+
* Thread Safety:
21+
* - All inference operations are protected by a mutex
22+
* - generateFromFrame() uses try_lock() to skip frames when the model is busy
23+
* - This prevents blocking the camera thread and maintains smooth frame rates
24+
*
25+
* Usage:
26+
* Subclasses should:
27+
* 1. Inherit from VisionModel instead of BaseModel
28+
* 2. Implement preprocessFrame() with model-specific preprocessing
29+
* 3. Use inference_mutex_ when calling forward() in custom generate methods
30+
* 4. Use lock_guard for blocking operations (JS API)
31+
* 5. Use try_lock() for non-blocking operations (camera API)
32+
*
33+
* Example:
34+
* @code
35+
* class Classification : public VisionModel {
36+
* public:
37+
* std::unordered_map<std::string_view, float>
38+
* generateFromFrame(jsi::Runtime& runtime, const jsi::Value& frameValue) {
39+
* // Extract the frame first; the inference lock is taken explicitly below
40+
* auto frameObject = frameValue.asObject(runtime);
41+
* cv::Mat frame = utils::FrameProcessor::extractFrame(runtime, frameObject);
42+
*
43+
* // Lock before inference
44+
* if (!inference_mutex_.try_lock()) {
45+
* return {}; // Skip frame if busy
46+
* }
47+
* std::lock_guard<std::mutex> lock(inference_mutex_, std::adopt_lock);
48+
*
49+
* auto preprocessed = preprocessFrame(frame);
50+
* // ... run inference
51+
* }
52+
* };
53+
* @endcode
54+
*/
55+
class VisionModel : public BaseModel {
56+
public:
57+
/**
58+
* @brief Construct a VisionModel with the same parameters as BaseModel
59+
*
60+
* VisionModel uses the same construction pattern as BaseModel, just adding
61+
* thread-safety on top.
*
* @param modelSource Forwarded unchanged to the BaseModel constructor
* @param callInvoker Forwarded unchanged to the BaseModel constructor
62+
*/
63+
VisionModel(const std::string &modelSource,
64+
std::shared_ptr<react::CallInvoker> callInvoker)
65+
: BaseModel(modelSource, callInvoker) {}
66+
67+
/**
68+
* @brief Virtual destructor for proper cleanup in derived classes
69+
*/
70+
virtual ~VisionModel() = default;
71+
72+
protected:
73+
/**
74+
* @brief Mutex to ensure thread-safe inference
75+
*
76+
* This mutex protects against race conditions when:
77+
* - generateFromFrame() is called from VisionCamera worklet thread (30fps)
78+
* - generate() is called from JavaScript thread simultaneously
79+
*
80+
* Usage guidelines:
81+
* - Use std::lock_guard for blocking operations (JS API can wait)
82+
* - Use try_lock() for non-blocking operations (camera should skip frames)
83+
*
* Nothing in this class locks the mutex itself; subclasses are responsible
* for acquiring it around their inference calls.
*
84+
* @note Marked mutable to allow locking in const methods if needed
85+
*/
86+
mutable std::mutex inference_mutex_;
87+
88+
/**
89+
* @brief Preprocess a camera frame for model input
90+
*
91+
* This method should implement model-specific preprocessing such as:
92+
* - Resizing to the model's expected input size
93+
* - Color space conversion (e.g., BGR to RGB)
94+
* - Normalization
95+
* - Any other model-specific transformations
96+
*
97+
* @param frame Input frame from camera; extractAndPreprocess() passes the
98+
* result of utils::FrameProcessor::extractFrame() here
99+
* @return Preprocessed cv::Mat ready for tensor conversion
100+
*
101+
* @note NOTE(review): the channel layout depends on the extraction path. The
* ArrayBuffer path can hand over a 4-channel (CV_8UC4) or a 3-channel
* (CV_8UC3) matrix without any colour conversion or rotation; whether
* the nativeBuffer path delivers RGB rotated 90° clockwise depends on
* FrameExtractor - confirm before relying on it.
102+
* @note extractAndPreprocess() invokes this method without acquiring
* inference_mutex_; callers that need mutual exclusion (presumably
* generateFromFrame() in subclasses - confirm) must already hold it.
103+
*/
104+
virtual cv::Mat preprocessFrame(const cv::Mat &frame) const = 0;
105+
106+
/**
107+
* @brief Extract and preprocess frame from VisionCamera in one call
108+
*
109+
* This is a convenience method that combines frame extraction and
110+
* preprocessing. It handles both nativeBuffer (zero-copy) and ArrayBuffer
111+
* paths automatically.
112+
*
113+
* @param runtime JSI runtime
114+
* @param frameData JSI value containing frame data from VisionCamera
115+
*
116+
* @return Preprocessed cv::Mat ready for tensor conversion
117+
*
118+
* @throws std::runtime_error if frame extraction fails
119+
*
120+
* @note This method does NOT acquire the inference mutex - caller is
121+
* responsible
122+
* @note Typical usage:
123+
* @code
124+
* cv::Mat preprocessed = extractAndPreprocess(runtime, frameData);
125+
* auto tensor = image_processing::getTensorFromMatrix(dims, preprocessed);
126+
* @endcode
127+
*/
128+
cv::Mat extractAndPreprocess(jsi::Runtime &runtime,
129+
const jsi::Value &frameData) const;
130+
};
131+
132+
} // namespace models
133+
// Register VisionModel constructor traits.
134+
// Even though VisionModel is abstract (preprocessFrame() is pure virtual), the
135+
// metaprogramming system needs to know its constructor signature for derived
// classes. The types listed below mirror the VisionModel constructor
// parameters: the model source string and the shared CallInvoker.
136+
REGISTER_CONSTRUCTOR(models::VisionModel, std::string,
137+
std::shared_ptr<react::CallInvoker>);
138+
139+
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,4 @@ Classification::postprocess(const Tensor &tensor) {
7373
return probs;
7474
}
7575

76-
} // namespace rnexecutorch::models::classification
76+
} // namespace rnexecutorch::models::classification
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
#include "FrameProcessor.h"
2+
#include "FrameExtractor.h"
3+
#include <rnexecutorch/Log.h>
4+
#include <stdexcept>
5+
6+
namespace rnexecutorch {
7+
namespace utils {
8+
9+
/**
 * Extracts a cv::Mat from a VisionCamera frame object.
 *
 * Strategy:
 *  1. Read the frame dimensions through getFrameSize(), which rejects frames
 *     lacking a width or height property with a descriptive error.
 *  2. If the object exposes a nativeBuffer property, try the zero-copy path.
 *     Any std::exception raised there is logged and extraction falls back to
 *     step 3.
 *  3. If the object exposes a data property, decode it as an ArrayBuffer.
 *
 * @param runtime   JSI runtime used to access the frame object
 * @param frameData Frame object carrying width, height and either
 *                  nativeBuffer or data
 * @return The extracted frame
 * @throws std::runtime_error if width/height are missing, or if no usable
 *         pixel source is available (the message distinguishes "nativeBuffer
 *         failed and there is no fallback" from "no source at all")
 */
cv::Mat FrameProcessor::extractFrame(jsi::Runtime &runtime,
                                     const jsi::Object &frameData) {
  // Reuse the shared helper so missing dimensions are reported consistently
  // instead of surfacing as a conversion error on an undefined property.
  const cv::Size frameSize = getFrameSize(runtime, frameData);
  const int width = frameSize.width;
  const int height = frameSize.height;

  // Remember a failed zero-copy attempt so the final error can report it.
  bool nativeAttemptFailed = false;
  std::string nativeFailureReason;

  // Try zero-copy path first (nativeBuffer)
  if (hasNativeBuffer(runtime, frameData)) {
    static bool loggedPath = false;
    if (!loggedPath) {
      log(LOG_LEVEL::Debug, "FrameProcessor: Using zero-copy nativeBuffer");
      loggedPath = true;
    }

    try {
      return extractFromNativeBuffer(runtime, frameData, width, height);
    } catch (const std::exception &e) {
      nativeAttemptFailed = true;
      nativeFailureReason = e.what();
      log(LOG_LEVEL::Debug,
          "FrameProcessor: nativeBuffer extraction failed: ", e.what());
      log(LOG_LEVEL::Debug, "FrameProcessor: Falling back to ArrayBuffer");
    }
  }

  // Fallback to ArrayBuffer path (with copy)
  if (frameData.hasProperty(runtime, "data")) {
    static bool loggedPath = false;
    if (!loggedPath) {
      log(LOG_LEVEL::Debug, "FrameProcessor: Using ArrayBuffer (with copy)");
      loggedPath = true;
    }

    return extractFromArrayBuffer(runtime, frameData, width, height);
  }

  // A nativeBuffer was present but unusable and there is nothing to fall back
  // to: report the real cause rather than claiming no nativeBuffer was found.
  if (nativeAttemptFailed) {
    throw std::runtime_error(
        "FrameProcessor: nativeBuffer extraction failed (" +
        nativeFailureReason +
        ") and no data property is available for fallback");
  }

  // No valid frame data source
  throw std::runtime_error(
      "FrameProcessor: No valid frame data (neither nativeBuffer nor data "
      "property found)");
}
50+
51+
/**
 * Reads the frame dimensions from a frame object.
 *
 * @param runtime   JSI runtime used to access the frame object
 * @param frameData Frame object expected to carry numeric width and height
 * @return cv::Size built from the width and height properties, each
 *         truncated to int
 * @throws std::runtime_error if either property is absent
 */
cv::Size FrameProcessor::getFrameSize(jsi::Runtime &runtime,
                                      const jsi::Object &frameData) {
  // Both dimensions must be present; height is only probed when width exists.
  const bool dimensionsPresent = frameData.hasProperty(runtime, "width") &&
                                 frameData.hasProperty(runtime, "height");
  if (!dimensionsPresent) {
    throw std::runtime_error("FrameProcessor: Frame data missing width or "
                             "height property");
  }

  // JS numbers are doubles; narrow each one to int the same way.
  const auto readDimension = [&runtime, &frameData](const char *key) {
    return static_cast<int>(frameData.getProperty(runtime, key).asNumber());
  };

  const int frameWidth = readDimension("width");
  const int frameHeight = readDimension("height");
  return cv::Size(frameWidth, frameHeight);
}
66+
67+
/**
 * Reports whether the frame object exposes a nativeBuffer property.
 *
 * Only the presence of the property is checked; its value is not inspected.
 */
bool FrameProcessor::hasNativeBuffer(jsi::Runtime &runtime,
                                     const jsi::Object &frameData) {
  const bool nativeBufferPresent =
      frameData.hasProperty(runtime, "nativeBuffer");
  return nativeBufferPresent;
}
71+
72+
/**
 * Builds a cv::Mat from the frame's nativeBuffer property.
 *
 * The property is read as a BigInt, converted to a 64-bit unsigned value and
 * passed to FrameExtractor::extractFromNativeBuffer(). The reported width and
 * height are used only for a sanity check: a mismatch with the extracted
 * frame is logged at debug level and the frame is returned regardless.
 *
 * @param runtime   JSI runtime used to access the frame object
 * @param frameData Frame object carrying the nativeBuffer property
 * @param width     Width reported by the frame object
 * @param height    Height reported by the frame object
 * @return The matrix produced by FrameExtractor
 */
cv::Mat FrameProcessor::extractFromNativeBuffer(jsi::Runtime &runtime,
                                                const jsi::Object &frameData,
                                                int width, int height) {
  // JavaScript passes the buffer handle as a bigint.
  const auto bufferAddress =
      static_cast<uint64_t>(frameData.getProperty(runtime, "nativeBuffer")
                                .asBigInt(runtime)
                                .asUint64(runtime));

  // Platform-specific conversion lives in FrameExtractor.
  cv::Mat extracted = FrameExtractor::extractFromNativeBuffer(bufferAddress);

  // Sanity check only: a size disagreement is reported, not treated as fatal.
  const bool sizeMatches = extracted.cols == width && extracted.rows == height;
  if (!sizeMatches) {
    log(LOG_LEVEL::Debug, "FrameProcessor: Dimension mismatch - expected ",
        width, "x", height, " but got ", extracted.cols, "x", extracted.rows);
  }

  return extracted;
}
92+
93+
/**
 * Wraps the frame's data ArrayBuffer in a cv::Mat.
 *
 * The pixel format is inferred from the number of bytes available per row
 * (buffer size divided by height):
 *  - at least width * 4 bytes per row -> 4-channel (CV_8UC4), rows may be
 *    padded
 *  - at least width * 3 bytes per row -> 3-channel (CV_8UC3), rows may be
 *    padded
 *  - anything smaller is rejected
 *
 * @param runtime   JSI runtime used to access the frame object
 * @param frameData Frame object carrying the data ArrayBuffer
 * @param width     Frame width in pixels; must be positive
 * @param height    Frame height in pixels; must be positive
 * @return cv::Mat header over the ArrayBuffer memory
 * @throws std::runtime_error if the dimensions are not positive or the
 *         buffer is too small for either supported format
 *
 * @note The returned matrix is constructed from the ArrayBuffer's data
 *       pointer, so it references that memory rather than owning a copy.
 *       It must not be used after the underlying JS buffer is released.
 */
cv::Mat FrameProcessor::extractFromArrayBuffer(jsi::Runtime &runtime,
                                               const jsi::Object &frameData,
                                               int width, int height) {
  // Reject non-positive dimensions up front: height == 0 would divide by
  // zero below, and negative values would wrap when converted to size_t.
  if (width <= 0 || height <= 0) {
    throw std::runtime_error("FrameProcessor: Invalid frame dimensions " +
                             std::to_string(width) + "x" +
                             std::to_string(height));
  }

  auto pixelData = frameData.getProperty(runtime, "data");
  auto arrayBuffer = pixelData.asObject(runtime).getArrayBuffer(runtime);
  uint8_t *data = arrayBuffer.data(runtime);
  const size_t bufferSize = arrayBuffer.size(runtime);

  // Do all size arithmetic in size_t so large frames cannot overflow int.
  const size_t cols = static_cast<size_t>(width);
  const size_t rows = static_cast<size_t>(height);

  // Determine format based on buffer size
  const size_t stride = bufferSize / rows;
  const size_t expectedRGBAStride = cols * 4;
  const size_t expectedRGBStride = cols * 3;

  cv::Mat frame;

  // For positive dimensions, stride >= width * 4 is equivalent to
  // bufferSize >= width * height * 4, and also covers the exact-match case.
  if (stride >= expectedRGBAStride) {
    // RGBA format with potential padding
    frame = cv::Mat(height, width, CV_8UC4, data, stride);

    static bool loggedFormat = false;
    if (!loggedFormat) {
      log(LOG_LEVEL::Debug,
          "FrameProcessor: ArrayBuffer format is RGBA, "
          "stride: ",
          stride);
      loggedFormat = true;
    }
  } else if (stride >= expectedRGBStride) {
    // RGB format
    frame = cv::Mat(height, width, CV_8UC3, data, stride);

    static bool loggedFormat = false;
    if (!loggedFormat) {
      log(LOG_LEVEL::Debug,
          "FrameProcessor: ArrayBuffer format is RGB, stride: ", stride);
      loggedFormat = true;
    }
  } else {
    throw std::runtime_error(
        "FrameProcessor: Unexpected buffer size - expected " +
        std::to_string(expectedRGBStride) + " or " +
        std::to_string(expectedRGBAStride) + " bytes per row, got " +
        std::to_string(stride));
  }

  return frame;
}
140+
141+
} // namespace utils
142+
} // namespace rnexecutorch

0 commit comments

Comments
 (0)