Skip to content

Commit ec04754

Browse files
committed
chore: review changes
1 parent ae91c0c commit ec04754

1 file changed

Lines changed: 51 additions & 36 deletions

File tree

packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include <executorch/extension/tensor/tensor.h>
77
#include <rnexecutorch/Error.h>
88
#include <rnexecutorch/data_processing/ImageProcessing.h>
9-
#include <rnexecutorch/data_processing/Numerical.h>
109
#include <rnexecutorch/models/BaseModel.h>
1110

1211
namespace rnexecutorch::models::image_segmentation {
@@ -92,56 +91,72 @@ std::shared_ptr<jsi::Object> BaseImageSegmentation::postprocess(
9291
std::size_t outputPixels = outputH * outputW;
9392
cv::Size outputSize(outputW, outputH);
9493

95-
// Work with vectors, only wrap into OwningArrayBuffer at the end
96-
std::vector<std::vector<float>> classBuffers;
97-
std::vector<int32_t> argmaxData(outputPixels);
94+
// Copy class data directly into OwningArrayBuffers (single copy from span)
95+
std::vector<std::shared_ptr<OwningArrayBuffer>> resultClasses;
96+
resultClasses.reserve(numChannels);
9897

9998
if (numChannels == 1) {
10099
// Binary segmentation (e.g. selfie segmentation)
101-
std::vector<float> bg(outputPixels);
102-
std::vector<float> fg(outputPixels);
100+
auto fg = std::make_shared<OwningArrayBuffer>(resultData.data(),
101+
outputPixels * sizeof(float));
102+
auto bg = std::make_shared<OwningArrayBuffer>(outputPixels * sizeof(float));
103+
auto *fgPtr = reinterpret_cast<float *>(fg->data());
104+
auto *bgPtr = reinterpret_cast<float *>(bg->data());
103105
for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) {
104-
float p = resultData[pixel];
105-
bg[pixel] = 1.0f - p;
106-
fg[pixel] = p;
107-
argmaxData[pixel] = (p > 0.5f) ? 1 : 0;
106+
bgPtr[pixel] = 1.0f - fgPtr[pixel];
108107
}
109-
classBuffers = {std::move(bg), std::move(fg)};
108+
resultClasses.push_back(bg);
109+
resultClasses.push_back(fg);
110110
} else {
111111
// Multi-class segmentation (e.g. DeepLab, RF-DETR)
112-
classBuffers.resize(numChannels);
113112
for (std::size_t cl = 0; cl < numChannels; ++cl) {
114-
classBuffers[cl].assign(resultData.data() + cl * outputPixels,
115-
resultData.data() + (cl + 1) * outputPixels);
113+
resultClasses.push_back(std::make_shared<OwningArrayBuffer>(
114+
resultData.data() + cl * outputPixels, outputPixels * sizeof(float)));
116115
}
116+
}
117+
118+
// Softmax + argmax in class-major order
119+
auto argmax =
120+
std::make_shared<OwningArrayBuffer>(outputPixels * sizeof(int32_t));
121+
auto *argmaxPtr = reinterpret_cast<int32_t *>(argmax->data());
117122

118-
// Apply softmax and compute argmax per pixel
123+
if (numChannels == 1) {
124+
auto *fgPtr = reinterpret_cast<float *>(resultClasses[1]->data());
119125
for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) {
120-
std::vector<float> values(numChannels);
121-
for (std::size_t cl = 0; cl < numChannels; ++cl) {
122-
values[cl] = classBuffers[cl][pixel];
123-
}
124-
numerical::softmax(values);
125-
126-
float maxVal = values[0];
127-
int maxInd = 0;
128-
for (std::size_t cl = 0; cl < numChannels; ++cl) {
129-
classBuffers[cl][pixel] = values[cl];
130-
if (values[cl] > maxVal) {
131-
maxVal = values[cl];
132-
maxInd = static_cast<int>(cl);
126+
argmaxPtr[pixel] = (fgPtr[pixel] > 0.5f) ? 1 : 0;
127+
}
128+
} else {
129+
std::vector<float> maxLogits(outputPixels,
130+
-std::numeric_limits<float>::infinity());
131+
std::vector<float> sumExp(outputPixels, 0.0f);
132+
133+
// Pass 1: find per-pixel max and argmax
134+
for (std::size_t cl = 0; cl < numChannels; ++cl) {
135+
auto *clPtr = reinterpret_cast<float *>(resultClasses[cl]->data());
136+
for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) {
137+
if (clPtr[pixel] > maxLogits[pixel]) {
138+
maxLogits[pixel] = clPtr[pixel];
139+
argmaxPtr[pixel] = static_cast<int32_t>(cl);
133140
}
134141
}
135-
argmaxData[pixel] = maxInd;
136142
}
137-
}
138143

139-
// Wrap into OwningArrayBuffers
140-
auto argmax = std::make_shared<OwningArrayBuffer>(argmaxData);
141-
std::vector<std::shared_ptr<OwningArrayBuffer>> resultClasses;
142-
resultClasses.reserve(classBuffers.size());
143-
for (auto &buf : classBuffers) {
144-
resultClasses.push_back(std::make_shared<OwningArrayBuffer>(buf));
144+
// Pass 2: subtract max, exp, accumulate sum
145+
for (std::size_t cl = 0; cl < numChannels; ++cl) {
146+
auto *clPtr = reinterpret_cast<float *>(resultClasses[cl]->data());
147+
for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) {
148+
clPtr[pixel] = std::exp(clPtr[pixel] - maxLogits[pixel]);
149+
sumExp[pixel] += clPtr[pixel];
150+
}
151+
}
152+
153+
// Pass 3: normalize by sum
154+
for (std::size_t cl = 0; cl < numChannels; ++cl) {
155+
auto *clPtr = reinterpret_cast<float *>(resultClasses[cl]->data());
156+
for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) {
157+
clPtr[pixel] /= sumExp[pixel];
158+
}
159+
}
145160
}
146161

147162
// Filter classes of interest

0 commit comments

Comments
 (0)