Skip to content

Commit 0e9a273

Browse files
authored
feat: Image segmentation for ios (#113)
## Description Add semantic image segmentation with iOS native code. The result of a model run is a map containing masks of per-pixel probability for specified classes and an 'argmax' mask containing the indices of max value labels for each pixel. ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [ ] Android ### Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [x] My changes generate no new warnings
1 parent 19e23ac commit 0e9a273

20 files changed

Lines changed: 424 additions & 4 deletions
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#import <RnExecutorchSpec/RnExecutorchSpec.h>
2+
3+
@interface ImageSegmentation : NSObject <NativeImageSegmentationSpec>
4+
5+
@end
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#import "ImageSegmentation.h"
2+
#import "models/image_segmentation/ImageSegmentationModel.h"
3+
#import "models/BaseModel.h"
4+
#import "utils/ETError.h"
5+
#import <ExecutorchLib/ETModel.h>
6+
#import <React/RCTBridgeModule.h>
7+
#import <opencv2/opencv.hpp>
8+
#import "ImageProcessor.h"
9+
10+
@implementation ImageSegmentation {
11+
ImageSegmentationModel *model;
12+
}
13+
14+
RCT_EXPORT_MODULE()
15+
16+
- (void)loadModule:(NSString *)modelSource
17+
resolve:(RCTPromiseResolveBlock)resolve
18+
reject:(RCTPromiseRejectBlock)reject {
19+
20+
model = [[ImageSegmentationModel alloc] init];
21+
[model
22+
loadModel:[NSURL URLWithString:modelSource]
23+
completion:^(BOOL success, NSNumber *errorCode) {
24+
if (success) {
25+
resolve(errorCode);
26+
return;
27+
}
28+
29+
reject(@"init_module_error",
30+
[NSString stringWithFormat:@"%ld", (long)[errorCode longValue]],
31+
nil);
32+
return;
33+
}];
34+
}
35+
36+
- (void)forward:(NSString *)input
37+
classesOfInterest:(NSArray *)classesOfInterest
38+
resize:(BOOL)resize
39+
resolve:(RCTPromiseResolveBlock)resolve
40+
reject:(RCTPromiseRejectBlock)reject {
41+
42+
@try {
43+
cv::Mat image = [ImageProcessor readImage:input];
44+
NSDictionary *result = [model runModel:image
45+
returnClasses:classesOfInterest
46+
resize:resize];
47+
48+
resolve(result);
49+
return;
50+
} @catch (NSException *exception) {
51+
NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason);
52+
reject(@"forward_error",
53+
[NSString stringWithFormat:@"%@", exception.reason], nil);
54+
return;
55+
}
56+
}
57+
58+
- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
59+
(const facebook::react::ObjCTurboModule::InitParams &)params {
60+
return std::make_shared<facebook::react::NativeImageSegmentationSpecJSI>(params);
61+
}
62+
63+
@end

ios/RnExecutorch/StyleTransfer.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#import "StyleTransfer.h"
22
#import "ImageProcessor.h"
33
#import "models/BaseModel.h"
4-
#import "models/StyleTransferModel.h"
4+
#import "models/style_transfer/StyleTransferModel.h"
55
#import "utils/ETError.h"
66
#import <ExecutorchLib/ETModel.h>
77
#import <React/RCTBridgeModule.h>

ios/RnExecutorch/models/classification/ClassificationModel.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#import "ClassificationModel.h"
22
#import "../../utils/ImageProcessor.h"
3+
#import "../../utils/Numerical.h"
34
#import "Constants.h"
4-
#import "Utils.h"
55
#import "opencv2/opencv.hpp"
66

77
@implementation ClassificationModel
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#import <string>
2+
#import <vector>
3+
4+
5+
extern const std::vector<std::string> deeplabv3_resnet50_labels;
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#import "Constants.h"
2+
#import <string>
3+
#import <vector>
4+
5+
const std::vector<std::string> deeplabv3_resnet50_labels = {
6+
"BACKGROUND", "AEROPLANE", "BICYCLE", "BIRD", "BOAT",
7+
"BOTTLE", "BUS", "CAR", "CAT", "CHAIR", "COW", "DININGTABLE",
8+
"DOG", "HORSE", "MOTORBIKE", "PERSON", "POTTEDPLANT", "SHEEP",
9+
"SOFA", "TRAIN", "TVMONITOR"
10+
};
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#import "../BaseModel.h"
2+
#import "opencv2/opencv.hpp"
3+
4+
@interface ImageSegmentationModel : BaseModel
5+
- (cv::Size)getModelImageSize;
6+
- (NSDictionary *)runModel:(cv::Mat &)input
7+
returnClasses:(NSArray *)classesOfInterest
8+
resize:(BOOL)resize;
9+
10+
@end
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
#import "ImageSegmentationModel.h"
2+
#import <unordered_set>
3+
#import <algorithm>
4+
#import <vector>
5+
#import "../../utils/ImageProcessor.h"
6+
#import "../../utils/Numerical.h"
7+
#import "../../utils/Conversions.h"
8+
#import "opencv2/opencv.hpp"
9+
#import "Constants.h"
10+
11+
@interface ImageSegmentationModel ()
12+
- (NSArray *)preprocess:(cv::Mat &)input;
13+
- (NSDictionary *)postprocess:(NSArray *)output
14+
returnClasses:(NSArray *)classesOfInterest
15+
resize:(BOOL)resize;
16+
@end
17+
18+
@implementation ImageSegmentationModel {
19+
cv::Size originalSize;
20+
}
21+
22+
- (cv::Size)getModelImageSize {
23+
NSArray *inputShape = [module getInputShape:@0];
24+
NSNumber *widthNumber = inputShape.lastObject;
25+
NSNumber *heightNumber = inputShape[inputShape.count - 2];
26+
27+
int height = [heightNumber intValue];
28+
int width = [widthNumber intValue];
29+
30+
return cv::Size(height, width);
31+
}
32+
33+
- (NSArray *)preprocess:(cv::Mat &)input {
34+
originalSize = cv::Size(input.cols, input.rows);
35+
36+
cv::Size modelImageSize = [self getModelImageSize];
37+
cv::Mat output;
38+
cv::resize(input, output, modelImageSize);
39+
40+
NSArray *modelInput = [ImageProcessor matToNSArray:output];
41+
return modelInput;
42+
}
43+
44+
std::vector<cv::Mat> extractResults(NSArray *result, std::size_t numLabels,
45+
cv::Size modelImageSize, cv::Size originalSize, BOOL resize) {
46+
std::size_t numModelPixels = modelImageSize.height * modelImageSize.width;
47+
48+
std::vector<cv::Mat> resizedLabelScores(numLabels);
49+
for (std::size_t label = 0; label < numLabels; ++label) {
50+
cv::Mat labelMat = cv::Mat(modelImageSize, CV_64F);
51+
52+
for(std::size_t pixel = 0; pixel < numModelPixels; ++pixel){
53+
int row = pixel / modelImageSize.width;
54+
int col = pixel % modelImageSize.width;
55+
labelMat.at<double>(row, col) = [result[label * numModelPixels + pixel] doubleValue];
56+
}
57+
58+
if (resize) {
59+
cv::resize(labelMat, resizedLabelScores[label], originalSize);
60+
}
61+
else {
62+
resizedLabelScores[label] = std::move(labelMat);
63+
}
64+
}
65+
return resizedLabelScores;
66+
}
67+
68+
void adjustScoresPerPixel(std::vector<cv::Mat>& labelScores, cv::Mat& argMax,
69+
cv::Size outputSize, std::size_t numLabels) {
70+
std::size_t numOutputPixels = outputSize.height * outputSize.width;
71+
for (std::size_t pixel = 0; pixel < numOutputPixels; ++pixel) {
72+
int row = pixel / outputSize.width;
73+
int col = pixel % outputSize.width;
74+
std::vector<double> scores;
75+
scores.reserve(numLabels);
76+
for (const auto& mat : labelScores) {
77+
scores.push_back(mat.at<double>(row, col));
78+
}
79+
80+
std::vector<double> adjustedScores = softmax(scores);
81+
82+
for (std::size_t label = 0; label < numLabels; ++label) {
83+
labelScores[label].at<double>(row, col) = adjustedScores[label];
84+
}
85+
86+
auto maxIt = std::max_element(scores.begin(), scores.end());
87+
argMax.at<int>(row, col) = std::distance(scores.begin(), maxIt);
88+
}
89+
}
90+
91+
- (NSDictionary *)postprocess:(NSArray *)output
92+
returnClasses:(NSArray *)classesOfInterest
93+
resize:(BOOL)resize {
94+
cv::Size modelImageSize = [self getModelImageSize];
95+
96+
std::size_t numLabels = deeplabv3_resnet50_labels.size();
97+
98+
NSAssert((std::size_t)output.count == numLabels * modelImageSize.height * modelImageSize.width,
99+
@"Model generated unexpected output size.");
100+
101+
// For each label extract it's matrix,
102+
// and rescale it to the original size if `resize`
103+
std::vector<cv::Mat> resizedLabelScores =
104+
extractResults(output, numLabels, modelImageSize, originalSize, resize);
105+
106+
cv::Size outputSize = resize ? originalSize : modelImageSize;
107+
cv::Mat argMax = cv::Mat(outputSize, CV_32S);
108+
109+
// For each pixel apply softmax across all the labels and calculate the argMax
110+
adjustScoresPerPixel(resizedLabelScores, argMax, outputSize, numLabels);
111+
112+
std::unordered_set<std::string> labelSet;
113+
114+
for (id label in classesOfInterest) {
115+
labelSet.insert(std::string([label UTF8String]));
116+
}
117+
118+
NSMutableDictionary *result = [NSMutableDictionary dictionary];
119+
120+
// Convert to NSArray and populate the final dictionary
121+
for (std::size_t label = 0; label < numLabels; ++label) {
122+
if (labelSet.contains(deeplabv3_resnet50_labels[label])){
123+
NSString *labelString = @(deeplabv3_resnet50_labels[label].c_str());
124+
NSArray *arr = simpleMatToNSArray<double>(resizedLabelScores[label]);
125+
result[labelString] = arr;
126+
}
127+
}
128+
129+
result[@"ARGMAX"] = simpleMatToNSArray<int>(argMax);
130+
131+
return result;
132+
}
133+
134+
- (NSDictionary *)runModel:(cv::Mat &)input
135+
returnClasses:(NSArray *)classesOfInterest
136+
resize:(BOOL)resize {
137+
NSArray *modelInput = [self preprocess:input];
138+
NSArray *result = [self forward:modelInput];
139+
140+
NSDictionary *output = [self postprocess:result[0]
141+
returnClasses:classesOfInterest
142+
resize:resize];
143+
144+
return output;
145+
}
146+
147+
@end

ios/RnExecutorch/models/StyleTransferModel.h renamed to ios/RnExecutorch/models/style_transfer/StyleTransferModel.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#import "BaseModel.h"
1+
#import "../BaseModel.h"
22
#import "opencv2/opencv.hpp"
33

44
@interface StyleTransferModel : BaseModel

ios/RnExecutorch/models/StyleTransferModel.mm renamed to ios/RnExecutorch/models/style_transfer/StyleTransferModel.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#import "StyleTransferModel.h"
2-
#import "../utils/ImageProcessor.h"
2+
#import "../../utils/ImageProcessor.h"
33
#import "opencv2/opencv.hpp"
44

55
@implementation StyleTransferModel {

0 commit comments

Comments
 (0)