Skip to content

Commit 2b11d1b

Browse files
committed
fix: types and revert bg, fg in selfie segmentation
1 parent 2f460be commit 2b11d1b

2 files changed

Lines changed: 12 additions & 9 deletions

File tree

packages/react-native-executorch/common/rnexecutorch/models/image_segmentation/BaseImageSegmentation.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ std::shared_ptr<jsi::Object> BaseImageSegmentation::postprocess(
105105
for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) {
106106
bgPtr[pixel] = 1.0f - fgPtr[pixel];
107107
}
108-
resultClasses.push_back(bg);
109108
resultClasses.push_back(fg);
109+
resultClasses.push_back(bg);
110110
} else {
111111
// Multi-class segmentation (e.g. DeepLab, RF-DETR)
112112
for (std::size_t cl = 0; cl < numChannels; ++cl) {
@@ -121,9 +121,9 @@ std::shared_ptr<jsi::Object> BaseImageSegmentation::postprocess(
121121
auto *argmaxPtr = reinterpret_cast<int32_t *>(argmax->data());
122122

123123
if (numChannels == 1) {
124-
auto *fgPtr = reinterpret_cast<float *>(resultClasses[1]->data());
124+
auto *fgPtr = reinterpret_cast<float *>(resultClasses[0]->data());
125125
for (std::size_t pixel = 0; pixel < outputPixels; ++pixel) {
126-
argmaxPtr[pixel] = (fgPtr[pixel] > 0.5f) ? 1 : 0;
126+
argmaxPtr[pixel] = (fgPtr[pixel] > 0.5f) ? 0 : 1;
127127
}
128128
} else {
129129
std::vector<float> maxLogits(outputPixels,

packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@ import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
1212
import { RnExecutorchError } from '../../errors/errorUtils';
1313
import { BaseModule } from '../BaseModule';
1414

15-
const ModelConfigs: Record<
16-
SegmentationModelName,
17-
SegmentationConfig<LabelEnum>
18-
> = {
15+
const ModelConfigs = {
1916
'deeplab-v3': {
2017
labelMap: DeeplabLabel,
2118
preprocessorConfig: undefined,
@@ -24,7 +21,10 @@ const ModelConfigs: Record<
2421
labelMap: SelfieSegmentationLabel,
2522
preprocessorConfig: undefined,
2623
},
27-
} as const;
24+
} as const satisfies Record<
25+
SegmentationModelName,
26+
SegmentationConfig<LabelEnum>
27+
>;
2828

2929
/** @internal */
3030
type ModelConfigsType = typeof ModelConfigs;
@@ -96,7 +96,10 @@ export class ImageSegmentationModule<
9696
onDownloadProgress: (progress: number) => void = () => {}
9797
): Promise<ImageSegmentationModule<ModelNameOf<C>>> {
9898
const { modelName, modelSource } = config;
99-
const { labelMap, preprocessorConfig } = ModelConfigs[modelName];
99+
const { labelMap } = ModelConfigs[modelName];
100+
const { preprocessorConfig } = ModelConfigs[
101+
modelName
102+
] as SegmentationConfig<LabelEnum>;
100103
const normMean = preprocessorConfig?.normMean ?? [];
101104
const normStd = preprocessorConfig?.normStd ?? [];
102105
const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource);

0 commit comments

Comments
 (0)