Skip to content

Commit bd0fa30

Browse files
committed
Remove redundant postprocessor config type
1 parent 3170ee5 commit bd0fa30

File tree

4 files changed

+66
-200
lines changed

4 files changed

+66
-200
lines changed

packages/react-native-executorch/src/hooks/computer_vision/useInstanceSegmentation.ts

Lines changed: 14 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import { useState, useEffect } from 'react';
21
import {
32
InstanceSegmentationModule,
43
InstanceSegmentationLabels,
@@ -8,10 +7,8 @@ import {
87
InstanceSegmentationType,
98
InstanceModelNameOf,
109
InstanceSegmentationModelSources,
11-
InstanceSegmentationOptions,
1210
} from '../../types/instanceSegmentation';
13-
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
14-
import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils';
11+
import { useModuleFactory } from '../useModuleFactory';
1512

1613
/**
1714
* React hook for managing an Instance Segmentation model instance.
@@ -47,78 +44,20 @@ export const useInstanceSegmentation = <
4744
>({
4845
model,
4946
preventLoad = false,
50-
}: InstanceSegmentationProps<C>): InstanceSegmentationType<LabelEnum> => {
51-
const [error, setError] = useState<RnExecutorchError | null>(null);
52-
const [isReady, setIsReady] = useState(false);
53-
const [isGenerating, setIsGenerating] = useState(false);
54-
const [downloadProgress, setDownloadProgress] = useState(0);
55-
const [instance, setInstance] = useState<InstanceSegmentationModule<
56-
InstanceModelNameOf<C>
57-
> | null>(null);
47+
}: InstanceSegmentationProps<C>): InstanceSegmentationType<
48+
InstanceSegmentationLabels<C['modelName']>
49+
> => {
50+
const { error, isReady, isGenerating, downloadProgress, runForward } =
51+
useModuleFactory<InstanceSegmentationModule<InstanceModelNameOf<C>>, C>({
52+
factory: InstanceSegmentationModule.fromModelName,
53+
config: model,
54+
preventLoad,
55+
});
5856

59-
useEffect(() => {
60-
if (preventLoad) return;
61-
62-
let isMounted = true;
63-
let currentInstance: InstanceSegmentationModule<
64-
InstanceModelNameOf<C>
65-
> | null = null;
66-
67-
(async () => {
68-
setDownloadProgress(0);
69-
setError(null);
70-
setIsReady(false);
71-
try {
72-
currentInstance = await InstanceSegmentationModule.fromModelName(
73-
model,
74-
(progress) => {
75-
if (isMounted) setDownloadProgress(progress);
76-
}
77-
);
78-
if (isMounted) {
79-
setInstance(currentInstance);
80-
setIsReady(true);
81-
}
82-
} catch (err) {
83-
if (isMounted) setError(parseUnknownError(err));
84-
}
85-
})();
86-
87-
return () => {
88-
isMounted = false;
89-
currentInstance?.delete();
90-
};
91-
92-
// eslint-disable-next-line react-hooks/exhaustive-deps
93-
}, [model.modelName, model.modelSource, preventLoad]);
94-
95-
const forward = async (
96-
imageSource: string,
97-
options?: InstanceSegmentationOptions<LabelEnum>
98-
) => {
99-
if (!isReady || !instance) {
100-
throw new RnExecutorchError(
101-
RnExecutorchErrorCode.ModuleNotLoaded,
102-
'The model is currently not loaded. Please load the model before calling forward().'
103-
);
104-
}
105-
if (isGenerating) {
106-
throw new RnExecutorchError(
107-
RnExecutorchErrorCode.ModelGenerating,
108-
'The model is currently generating. Please wait until previous model run is complete.'
109-
);
110-
}
111-
try {
112-
setIsGenerating(true);
113-
const result = await instance.forward(imageSource, options);
114-
return result as any;
115-
} catch (err) {
116-
setError(parseUnknownError(err));
117-
throw err;
118-
} finally {
119-
setIsGenerating(false);
120-
}
121-
};
57+
const forward: InstanceSegmentationType<
58+
InstanceSegmentationLabels<C['modelName']>
59+
>['forward'] = (imageSource, options) =>
60+
runForward((instance) => instance.forward(imageSource, options) as any);
12261

12362
return {
12463
error,

packages/react-native-executorch/src/index.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ declare global {
4646
source: string,
4747
normMean: number[] | [],
4848
normStd: number[] | [],
49-
applyNMS: boolean
49+
applyNMS: boolean,
50+
labelNames: string[]
5051
) => any;
5152
var loadClassification: (source: string) => any;
5253
var loadObjectDetection: (

packages/react-native-executorch/src/modules/computer_vision/InstanceSegmentationModule.ts

Lines changed: 40 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import { ResourceFetcher } from '../../utils/ResourceFetcher';
21
import { ResourceSource, LabelEnum } from '../../types/common';
32
import {
43
InstanceSegmentationModelSources,
@@ -11,18 +10,33 @@ import {
1110
import { CocoLabel } from '../../types/objectDetection';
1211
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
1312
import { RnExecutorchError } from '../../errors/errorUtils';
14-
import { BaseModule } from '../BaseModule';
13+
import { BaseLabeledModule, fetchModelPath } from '../BaseLabeledModule';
1514

16-
const YOLO_SEG_CONFIG = {
15+
const YOLO_SEG_CONFIG: InstanceSegmentationConfig<typeof CocoLabel> = {
1716
labelMap: CocoLabel,
18-
availableInputSizes: [384, 416, 512, 640, 1024] as const,
19-
defaultInputSize: 416,
17+
availableInputSizes: [384, 512, 640] as const,
18+
defaultInputSize: 384,
19+
defaultConfidenceThreshold: 0.5,
20+
defaultIouThreshold: 0.5,
2021
postprocessorConfig: {
21-
defaultConfidenceThreshold: 0.5,
22-
defaultIouThreshold: 0.5,
2322
applyNMS: true,
2423
},
25-
} as const;
24+
};
25+
26+
/**
27+
* Builds an ordered label name array from a label map, indexed by class ID.
28+
* Index i corresponds to class index i produced by the model.
29+
*/
30+
function buildLabelNames(labelMap: LabelEnum): string[] {
31+
const allLabelNames: string[] = [];
32+
for (const [name, value] of Object.entries(labelMap)) {
33+
if (typeof value === 'number') allLabelNames[value] = name;
34+
}
35+
for (let i = 0; i < allLabelNames.length; i++) {
36+
if (allLabelNames[i] == null) allLabelNames[i] = '';
37+
}
38+
return allLabelNames;
39+
}
2640

2741
const ModelConfigs = {
2842
'yolo26n-seg': YOLO_SEG_CONFIG,
@@ -84,24 +98,18 @@ type ResolveLabels<T extends InstanceSegmentationModelName | LabelEnum> =
8498
*/
8599
export class InstanceSegmentationModule<
86100
T extends InstanceSegmentationModelName | LabelEnum,
87-
> extends BaseModule {
88-
private labelMap: ResolveLabels<T>;
101+
> extends BaseLabeledModule<ResolveLabels<T>> {
89102
private modelConfig: InstanceSegmentationConfig<LabelEnum>;
90103

91104
private constructor(
92105
labelMap: ResolveLabels<T>,
93106
modelConfig: InstanceSegmentationConfig<LabelEnum>,
94107
nativeModule: unknown
95108
) {
96-
super();
97-
this.labelMap = labelMap;
109+
super(labelMap, nativeModule);
98110
this.modelConfig = modelConfig;
99-
this.nativeModule = nativeModule;
100111
}
101112

102-
// TODO: figure it out so we can delete this (we need this because of basemodule inheritance)
103-
override async load() {}
104-
105113
/**
106114
* Creates an instance segmentation module for a pre-configured model.
107115
* The config object is discriminated by `modelName` — each model can require different fields.
@@ -125,13 +133,7 @@ export class InstanceSegmentationModule<
125133
const { modelName, modelSource } = config;
126134
const modelConfig = ModelConfigs[modelName as keyof typeof ModelConfigs];
127135

128-
const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource);
129-
if (!paths?.[0]) {
130-
throw new RnExecutorchError(
131-
RnExecutorchErrorCode.DownloadInterrupted,
132-
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
133-
);
134-
}
136+
const path = await fetchModelPath(modelSource, onDownloadProgress);
135137

136138
if (typeof global.loadInstanceSegmentation !== 'function') {
137139
throw new RnExecutorchError(
@@ -140,12 +142,12 @@ export class InstanceSegmentationModule<
140142
);
141143
}
142144

143-
// Pass config parameters to native module
144145
const nativeModule = global.loadInstanceSegmentation(
145-
paths[0],
146+
path,
146147
modelConfig.preprocessorConfig?.normMean || [],
147148
modelConfig.preprocessorConfig?.normStd || [],
148-
modelConfig.postprocessorConfig.applyNMS ?? true
149+
modelConfig.postprocessorConfig?.applyNMS ?? true,
150+
buildLabelNames(modelConfig.labelMap)
149151
);
150152

151153
return new InstanceSegmentationModule<InstanceModelNameOf<C>>(
@@ -173,11 +175,9 @@ export class InstanceSegmentationModule<
173175
* labelMap: MyLabels,
174176
* availableInputSizes: [640],
175177
* defaultInputSize: 640,
176-
* postprocessorConfig: {
177-
* defaultConfidenceThreshold: 0.5,
178-
* defaultIouThreshold: 0.45,
179-
* applyNMS: true,
180-
* },
178+
* defaultConfidenceThreshold: 0.5,
179+
* defaultIouThreshold: 0.45,
180+
* postprocessorConfig: { applyNMS: true },
181181
* },
182182
* );
183183
* ```
@@ -187,13 +187,7 @@ export class InstanceSegmentationModule<
187187
config: InstanceSegmentationConfig<L>,
188188
onDownloadProgress: (progress: number) => void = () => {}
189189
): Promise<InstanceSegmentationModule<L>> {
190-
const paths = await ResourceFetcher.fetch(onDownloadProgress, modelSource);
191-
if (!paths?.[0]) {
192-
throw new RnExecutorchError(
193-
RnExecutorchErrorCode.DownloadInterrupted,
194-
'The download has been interrupted. Please retry.'
195-
);
196-
}
190+
const path = await fetchModelPath(modelSource, onDownloadProgress);
197191

198192
if (typeof global.loadInstanceSegmentation !== 'function') {
199193
throw new RnExecutorchError(
@@ -202,12 +196,12 @@ export class InstanceSegmentationModule<
202196
);
203197
}
204198

205-
// Pass config parameters to native module
206199
const nativeModule = global.loadInstanceSegmentation(
207-
paths[0],
200+
path,
208201
config.preprocessorConfig?.normMean || [],
209202
config.preprocessorConfig?.normStd || [],
210-
config.postprocessorConfig.applyNMS ?? true
203+
config.postprocessorConfig?.applyNMS ?? true,
204+
buildLabelNames(config.labelMap)
211205
);
212206

213207
return new InstanceSegmentationModule<L>(
@@ -252,20 +246,16 @@ export class InstanceSegmentationModule<
252246
);
253247
}
254248

255-
// Extract options with defaults from config
256249
const confidenceThreshold =
257250
options?.confidenceThreshold ??
258-
this.modelConfig.postprocessorConfig.defaultConfidenceThreshold ??
259-
0.55;
251+
this.modelConfig.defaultConfidenceThreshold ??
252+
0.5;
260253
const iouThreshold =
261-
options?.iouThreshold ??
262-
this.modelConfig.postprocessorConfig.defaultIouThreshold ??
263-
0.55;
254+
options?.iouThreshold ?? this.modelConfig.defaultIouThreshold ?? 0.5;
264255
const maxInstances = options?.maxInstances ?? 100;
265256
const returnMaskAtOriginalResolution =
266257
options?.returnMaskAtOriginalResolution ?? true;
267258

268-
// Get inputSize from options or use default
269259
const inputSize = options?.inputSize ?? this.modelConfig.defaultInputSize;
270260

271261
if (inputSize === undefined) {
@@ -275,7 +265,6 @@ export class InstanceSegmentationModule<
275265
);
276266
}
277267

278-
// Validate inputSize against available sizes
279268
if (
280269
this.modelConfig.availableInputSizes &&
281270
!this.modelConfig.availableInputSizes.includes(
@@ -288,7 +277,6 @@ export class InstanceSegmentationModule<
288277
);
289278
}
290279

291-
// Convert classesOfInterest labels to indices
292280
const classIndices = options?.classesOfInterest
293281
? options.classesOfInterest.map((label) => {
294282
const labelStr = String(label);
@@ -297,39 +285,14 @@ export class InstanceSegmentationModule<
297285
})
298286
: [];
299287

300-
// Measure inference time
301-
const startTime = performance.now();
302-
const nativeResult = await this.nativeModule.generate(
288+
return await this.nativeModule.generate(
303289
imageSource,
304290
confidenceThreshold,
305291
iouThreshold,
306292
maxInstances,
307293
classIndices,
308294
returnMaskAtOriginalResolution,
309-
inputSize // Pass inputSize as number instead of methodName as string
295+
inputSize
310296
);
311-
const endTime = performance.now();
312-
const inferenceTime = endTime - startTime;
313-
314-
console.log(
315-
`[Instance Segmentation] Inference completed in ${inferenceTime.toFixed(2)}ms | Input size: ${inputSize}x${inputSize} | Detected: ${nativeResult.length} instances`
316-
);
317-
318-
// Convert label indices back to label names
319-
// YOLO outputs 0-indexed class IDs, but COCO labels are 1-indexed, so add 1
320-
const reverseLabelMap = Object.entries(
321-
this.labelMap as Record<string, number>
322-
).reduce(
323-
(acc, [key, value]) => {
324-
acc[value as number] = key;
325-
return acc;
326-
},
327-
{} as Record<number, string>
328-
);
329-
330-
return nativeResult.map((instance: any) => ({
331-
...instance,
332-
label: reverseLabelMap[instance.label + 1] || `UNKNOWN_${instance.label}`,
333-
})) as SegmentedInstance<ResolveLabels<T>>[];
334297
}
335298
}

0 commit comments

Comments
 (0)