@@ -30,120 +30,148 @@ BaseInstanceSegmentation::BaseInstanceSegmentation(
3030 }
3131}
3232
33- cv::Mat BaseInstanceSegmentation::processMaskFromLogits (
34- const cv::Mat &logitsMat, float x1, float y1, float x2, float y2,
35- cv::Size modelInputSize, cv::Size originalSize, int32_t maskW,
36- int32_t maskH, int32_t bboxW, int32_t bboxH, float origX1, float origY1,
37- bool warpToOriginal, int32_t &outWidth, int32_t &outHeight) {
33+ std::tuple<utils::computer_vision::BBox, float , int32_t >
34+ BaseInstanceSegmentation::extractDetectionData (const float *bboxData,
35+ const float *scoresData,
36+ int32_t index) {
37+ utils::computer_vision::BBox bbox{
38+ bboxData[index * 4 ], bboxData[index * 4 + 1 ], bboxData[index * 4 + 2 ],
39+ bboxData[index * 4 + 3 ]};
40+ float score = scoresData[index * 2 ];
41+ int32_t label = static_cast <int32_t >(scoresData[index * 2 + 1 ]);
42+
43+ return {bbox, score, label};
44+ }
3845
39- float mx1F = x1 * maskW / modelInputSize.width ;
40- float my1F = y1 * maskH / modelInputSize.height ;
41- float mx2F = x2 * maskW / modelInputSize.width ;
42- float my2F = y2 * maskH / modelInputSize.height ;
46+ cv::Rect BaseInstanceSegmentation::computeMaskCropRect (
47+ const utils::computer_vision::BBox &bboxModel, cv::Size modelInputSize,
48+ cv::Size maskSize) {
49+
50+ float mx1F = bboxModel.x1 * maskSize.width / modelInputSize.width ;
51+ float my1F = bboxModel.y1 * maskSize.height / modelInputSize.height ;
52+ float mx2F = bboxModel.x2 * maskSize.width / modelInputSize.width ;
53+ float my2F = bboxModel.y2 * maskSize.height / modelInputSize.height ;
4354
4455 int32_t mx1 = std::max (0 , static_cast <int32_t >(std::floor (mx1F)));
4556 int32_t my1 = std::max (0 , static_cast <int32_t >(std::floor (my1F)));
46- int32_t mx2 = std::min (maskW, static_cast <int32_t >(std::ceil (mx2F)));
47- int32_t my2 = std::min (maskH, static_cast <int32_t >(std::ceil (my2F)));
57+ int32_t mx2 = std::min (maskSize.width , static_cast <int32_t >(std::ceil (mx2F)));
58+ int32_t my2 =
59+ std::min (maskSize.height , static_cast <int32_t >(std::ceil (my2F)));
60+
61+ return cv::Rect (mx1, my1, mx2 - mx1, my2 - my1);
62+ }
63+
64+ cv::Rect BaseInstanceSegmentation::addPaddingToRect (const cv::Rect &rect,
65+ cv::Size maskSize) {
66+ int32_t x1 = std::max (0 , rect.x - 1 );
67+ int32_t y1 = std::max (0 , rect.y - 1 );
68+ int32_t x2 = std::min (maskSize.width , rect.x + rect.width + 1 );
69+ int32_t y2 = std::min (maskSize.height , rect.y + rect.height + 1 );
70+
71+ return cv::Rect (x1, y1, x2 - x1, y2 - y1);
72+ }
4873
49- cv::Mat finalBinaryMat;
50- outWidth = bboxW;
51- outHeight = bboxH;
74+ cv::Mat BaseInstanceSegmentation::applySigmoid (const cv::Mat &logits) {
75+ cv::Mat probMat;
76+ cv::exp (-logits, probMat);
77+ probMat = 255 .0f / (1 .0f + probMat);
78+ probMat.convertTo (probMat, CV_8UC1);
79+ return probMat;
80+ }
81+
82+ cv::Mat BaseInstanceSegmentation::warpToOriginalResolution (
83+ const cv::Mat &probMat, const cv::Rect &maskRect, cv::Size originalSize,
84+ cv::Size maskSize, const utils::computer_vision::BBox &bboxOriginal) {
85+
86+ float scaleX = static_cast <float >(originalSize.width ) / maskSize.width ;
87+ float scaleY = static_cast <float >(originalSize.height ) / maskSize.height ;
88+
89+ cv::Mat M = (cv::Mat_<float >(2 , 3 ) << scaleX, 0 ,
90+ (maskRect.x * scaleX - bboxOriginal.x1 ), 0 , scaleY,
91+ (maskRect.y * scaleY - bboxOriginal.y1 ));
92+
93+ cv::Size bboxSize (static_cast <int32_t >(std::round (bboxOriginal.width ())),
94+ static_cast <int32_t >(std::round (bboxOriginal.height ())));
95+
96+ cv::Mat warped;
97+ cv::warpAffine (probMat, warped, M, bboxSize, cv::INTER_LINEAR);
98+ return warped;
99+ }
100+
101+ cv::Mat BaseInstanceSegmentation::thresholdToBinary (const cv::Mat &probMat) {
102+ cv::Mat binary;
103+ cv::threshold (probMat, binary, 127 , 1 , cv::THRESH_BINARY);
104+ return binary;
105+ }
106+
107+ cv::Mat BaseInstanceSegmentation::processMaskFromLogits (
108+ const cv::Mat &logitsMat, const utils::computer_vision::BBox &bboxModel,
109+ const utils::computer_vision::BBox &bboxOriginal, cv::Size modelInputSize,
110+ cv::Size originalSize, cv::Size maskSize, bool warpToOriginal,
111+ cv::Size &outSize) {
112+
113+ cv::Rect cropRect = computeMaskCropRect (bboxModel, modelInputSize, maskSize);
114+
115+ if (warpToOriginal) {
116+ cropRect = addPaddingToRect (cropRect, maskSize);
117+ }
118+
119+ cv::Mat cropped = logitsMat (cropRect);
120+ cv::Mat probMat = applySigmoid (cropped);
52121
53122 if (warpToOriginal) {
54- int32_t pmx1 = std::max (0 , mx1 - 1 );
55- int32_t pmy1 = std::max (0 , my1 - 1 );
56- int32_t pmx2 = std::min (maskW, mx2 + 1 );
57- int32_t pmy2 = std::min (maskH, my2 + 1 );
58-
59- cv::Mat croppedLogits =
60- logitsMat (cv::Rect (pmx1, pmy1, pmx2 - pmx1, pmy2 - pmy1));
61- cv::Mat probMat;
62- cv::exp (-croppedLogits, probMat);
63- probMat = 255 .0f / (1 .0f + probMat);
64- probMat.convertTo (probMat, CV_8UC1);
65-
66- float maskToOrigX = static_cast <float >(originalSize.width ) / maskW;
67- float maskToOrigY = static_cast <float >(originalSize.height ) / maskH;
68-
69- cv::Mat M =
70- (cv::Mat_<float >(2 , 3 ) << maskToOrigX, 0 , (pmx1 * maskToOrigX - origX1),
71- 0 , maskToOrigY, (pmy1 * maskToOrigY - origY1));
72-
73- cv::Mat warpedMat;
74- cv::warpAffine (probMat, warpedMat, M, cv::Size (bboxW, bboxH),
75- cv::INTER_LINEAR);
76-
77- cv::threshold (warpedMat, finalBinaryMat, 127 , 1 , cv::THRESH_BINARY);
78- } else {
79- cv::Mat croppedLogits = logitsMat (cv::Rect (mx1, my1, mx2 - mx1, my2 - my1));
80- cv::Mat probMat;
81- cv::exp (-croppedLogits, probMat);
82- probMat = 255 .0f / (1 .0f + probMat);
83- probMat.convertTo (probMat, CV_8UC1);
84-
85- cv::threshold (probMat, finalBinaryMat, 127 , 1 , cv::THRESH_BINARY);
86- outWidth = finalBinaryMat.cols ;
87- outHeight = finalBinaryMat.rows ;
123+ probMat = warpToOriginalResolution (probMat, cropRect, originalSize,
124+ maskSize, bboxOriginal);
88125 }
126+ cv::Mat binaryMask = thresholdToBinary (probMat);
127+ outSize = cv::Size (binaryMask.cols , binaryMask.rows );
89128
90- return finalBinaryMat ;
129+ return binaryMask ;
91130}
92131
93132std::optional<types::Instance> BaseInstanceSegmentation::processDetection (
94133 int32_t detectionIndex, const float *bboxData, const float *scoresData,
95- const float *maskData, int32_t maskH, int32_t maskW,
96- cv::Size modelInputSize, cv::Size originalSize, float widthRatio,
97- float heightRatio, double confidenceThreshold,
134+ const cv::Mat &logitsMat, cv::Size modelInputSize, cv::Size originalSize,
135+ float widthRatio, float heightRatio, double confidenceThreshold,
98136 const std::set<int32_t > &allowedClasses,
99137 bool returnMaskAtOriginalResolution) {
100138
101- int32_t i = detectionIndex;
102-
103- float x1 = bboxData[i * 4 + 0 ];
104- float y1 = bboxData[i * 4 + 1 ];
105- float x2 = bboxData[i * 4 + 2 ];
106- float y2 = bboxData[i * 4 + 3 ];
107- float score = scoresData[i * 2 + 0 ];
108- auto labelIdx = static_cast <std::size_t >(scoresData[i * 2 + 1 ]);
139+ // Extract detection data
140+ auto [bboxModel, score, labelIdx] =
141+ extractDetectionData (bboxData, scoresData, detectionIndex);
109142
143+ // Filter by confidence
110144 if (score < confidenceThreshold) {
111145 return std::nullopt ;
112146 }
113- if (!allowedClasses.empty () && allowedClasses.find (static_cast <int32_t >(
114- labelIdx)) == allowedClasses.end ()) {
147+
148+ // Filter by class
149+ if (!allowedClasses.empty () &&
150+ allowedClasses.find (labelIdx) == allowedClasses.end ()) {
115151 return std::nullopt ;
116152 }
117153
118154 // Scale bbox to original image coordinates
119- float origX1 = x1 * widthRatio;
120- float origY1 = y1 * heightRatio;
121- float origX2 = x2 * widthRatio;
122- float origY2 = y2 * heightRatio;
155+ utils::computer_vision::BBox bboxOriginal =
156+ bboxModel.scale (widthRatio, heightRatio);
123157
124- int32_t bboxW = static_cast <int32_t >(std::round (origX2 - origX1));
125- int32_t bboxH = static_cast <int32_t >(std::round (origY2 - origY1));
126-
127- if (bboxW <= 0 || bboxH <= 0 ) {
158+ if (!bboxOriginal.isValid ()) {
128159 return std::nullopt ;
129160 }
130161
131- const float *logits = maskData + (i * maskH * maskW);
132- cv::Mat logitsMat (maskH, maskW, CV_32FC1, const_cast <float *>(logits));
133-
134- int32_t finalMaskWidth, finalMaskHeight;
162+ // Process mask
163+ cv::Size maskSize (logitsMat.cols , logitsMat.rows );
164+ cv::Size outSize;
135165 cv::Mat finalBinaryMat = processMaskFromLogits (
136- logitsMat, x1, y1, x2, y2, modelInputSize, originalSize, maskW, maskH,
137- bboxW, bboxH, origX1, origY1, returnMaskAtOriginalResolution,
138- finalMaskWidth, finalMaskHeight);
166+ logitsMat, bboxModel, bboxOriginal, modelInputSize, originalSize,
167+ maskSize, returnMaskAtOriginalResolution, outSize);
139168
169+ // Create instance
140170 std::vector<uint8_t > finalMask (finalBinaryMat.data ,
141171 finalBinaryMat.data + finalBinaryMat.total ());
142172
143- return types::Instance (
144- utils::computer_vision::BBox{origX1, origY1, origX2, origY2},
145- std::move (finalMask), finalMaskWidth, finalMaskHeight,
146- static_cast <int32_t >(labelIdx), score, i);
173+ return types::Instance (bboxOriginal, std::move (finalMask), outSize.width ,
174+ outSize.height , labelIdx, score, detectionIndex);
147175}
148176
149177std::vector<types::Instance> BaseInstanceSegmentation::postprocess (
@@ -203,10 +231,14 @@ std::vector<types::Instance> BaseInstanceSegmentation::postprocess(
203231 int32_t processed = 0 ;
204232
205233 for (int32_t i = 0 ; i < N; ++i) {
234+ // Extract mask logits for this detection
235+ const float *logits = maskData + (i * maskH * maskW);
236+ cv::Mat logitsMat (maskH, maskW, CV_32FC1, const_cast <float *>(logits));
237+
206238 auto instance = processDetection (
207- i, bboxData, scoresData, maskData, maskH, maskW, modelInputSize ,
208- originalSize, widthRatio, heightRatio, confidenceThreshold,
209- allowedClasses, returnMaskAtOriginalResolution);
239+ i, bboxData, scoresData, logitsMat, modelInputSize, originalSize ,
240+ widthRatio, heightRatio, confidenceThreshold, allowedClasses ,
241+ returnMaskAtOriginalResolution);
210242
211243 if (instance.has_value ()) {
212244 instances.push_back (std::move (*instance));
0 commit comments