1- import { ResourceFetcher } from '../../utils/ResourceFetcher' ;
21import { ResourceSource , LabelEnum } from '../../types/common' ;
32import {
43 InstanceSegmentationModelSources ,
@@ -11,18 +10,33 @@ import {
1110import { CocoLabel } from '../../types/objectDetection' ;
1211import { RnExecutorchErrorCode } from '../../errors/ErrorCodes' ;
1312import { RnExecutorchError } from '../../errors/errorUtils' ;
14- import { BaseModule } from '../BaseModule ' ;
13+ import { BaseLabeledModule , fetchModelPath } from '../BaseLabeledModule ' ;
1514
16- const YOLO_SEG_CONFIG = {
15+ const YOLO_SEG_CONFIG : InstanceSegmentationConfig < typeof CocoLabel > = {
1716 labelMap : CocoLabel ,
18- availableInputSizes : [ 384 , 416 , 512 , 640 , 1024 ] as const ,
19- defaultInputSize : 416 ,
17+ availableInputSizes : [ 384 , 512 , 640 ] as const ,
18+ defaultInputSize : 384 ,
19+ defaultConfidenceThreshold : 0.5 ,
20+ defaultIouThreshold : 0.5 ,
2021 postprocessorConfig : {
21- defaultConfidenceThreshold : 0.5 ,
22- defaultIouThreshold : 0.5 ,
2322 applyNMS : true ,
2423 } ,
25- } as const ;
24+ } ;
25+
26+ /**
27+ * Builds an ordered label name array from a label map, indexed by class ID.
28+ * Index i corresponds to class index i produced by the model.
29+ */
30+ function buildLabelNames ( labelMap : LabelEnum ) : string [ ] {
31+ const allLabelNames : string [ ] = [ ] ;
32+ for ( const [ name , value ] of Object . entries ( labelMap ) ) {
33+ if ( typeof value === 'number' ) allLabelNames [ value ] = name ;
34+ }
35+ for ( let i = 0 ; i < allLabelNames . length ; i ++ ) {
36+ if ( allLabelNames [ i ] == null ) allLabelNames [ i ] = '' ;
37+ }
38+ return allLabelNames ;
39+ }
2640
2741const ModelConfigs = {
2842 'yolo26n-seg' : YOLO_SEG_CONFIG ,
@@ -84,24 +98,18 @@ type ResolveLabels<T extends InstanceSegmentationModelName | LabelEnum> =
8498 */
8599export class InstanceSegmentationModule <
86100 T extends InstanceSegmentationModelName | LabelEnum ,
87- > extends BaseModule {
88- private labelMap : ResolveLabels < T > ;
101+ > extends BaseLabeledModule < ResolveLabels < T > > {
89102 private modelConfig : InstanceSegmentationConfig < LabelEnum > ;
90103
91104 private constructor (
92105 labelMap : ResolveLabels < T > ,
93106 modelConfig : InstanceSegmentationConfig < LabelEnum > ,
94107 nativeModule : unknown
95108 ) {
96- super ( ) ;
97- this . labelMap = labelMap ;
109+ super ( labelMap , nativeModule ) ;
98110 this . modelConfig = modelConfig ;
99- this . nativeModule = nativeModule ;
100111 }
101112
102- // TODO: figure it out so we can delete this (we need this because of basemodule inheritance)
103- override async load ( ) { }
104-
105113 /**
106114 * Creates an instance segmentation module for a pre-configured model.
107115 * The config object is discriminated by `modelName` — each model can require different fields.
@@ -125,13 +133,7 @@ export class InstanceSegmentationModule<
125133 const { modelName, modelSource } = config ;
126134 const modelConfig = ModelConfigs [ modelName as keyof typeof ModelConfigs ] ;
127135
128- const paths = await ResourceFetcher . fetch ( onDownloadProgress , modelSource ) ;
129- if ( ! paths ?. [ 0 ] ) {
130- throw new RnExecutorchError (
131- RnExecutorchErrorCode . DownloadInterrupted ,
132- 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
133- ) ;
134- }
136+ const path = await fetchModelPath ( modelSource , onDownloadProgress ) ;
135137
136138 if ( typeof global . loadInstanceSegmentation !== 'function' ) {
137139 throw new RnExecutorchError (
@@ -140,12 +142,12 @@ export class InstanceSegmentationModule<
140142 ) ;
141143 }
142144
143- // Pass config parameters to native module
144145 const nativeModule = global . loadInstanceSegmentation (
145- paths [ 0 ] ,
146+ path ,
146147 modelConfig . preprocessorConfig ?. normMean || [ ] ,
147148 modelConfig . preprocessorConfig ?. normStd || [ ] ,
148- modelConfig . postprocessorConfig . applyNMS ?? true
149+ modelConfig . postprocessorConfig ?. applyNMS ?? true ,
150+ buildLabelNames ( modelConfig . labelMap )
149151 ) ;
150152
151153 return new InstanceSegmentationModule < InstanceModelNameOf < C > > (
@@ -173,11 +175,9 @@ export class InstanceSegmentationModule<
173175 * labelMap: MyLabels,
174176 * availableInputSizes: [640],
175177 * defaultInputSize: 640,
176- * postprocessorConfig: {
177- * defaultConfidenceThreshold: 0.5,
178- * defaultIouThreshold: 0.45,
179- * applyNMS: true,
180- * },
178+ * defaultConfidenceThreshold: 0.5,
179+ * defaultIouThreshold: 0.45,
180+ * postprocessorConfig: { applyNMS: true },
181181 * },
182182 * );
183183 * ```
@@ -187,13 +187,7 @@ export class InstanceSegmentationModule<
187187 config : InstanceSegmentationConfig < L > ,
188188 onDownloadProgress : ( progress : number ) => void = ( ) => { }
189189 ) : Promise < InstanceSegmentationModule < L > > {
190- const paths = await ResourceFetcher . fetch ( onDownloadProgress , modelSource ) ;
191- if ( ! paths ?. [ 0 ] ) {
192- throw new RnExecutorchError (
193- RnExecutorchErrorCode . DownloadInterrupted ,
194- 'The download has been interrupted. Please retry.'
195- ) ;
196- }
190+ const path = await fetchModelPath ( modelSource , onDownloadProgress ) ;
197191
198192 if ( typeof global . loadInstanceSegmentation !== 'function' ) {
199193 throw new RnExecutorchError (
@@ -202,12 +196,12 @@ export class InstanceSegmentationModule<
202196 ) ;
203197 }
204198
205- // Pass config parameters to native module
206199 const nativeModule = global . loadInstanceSegmentation (
207- paths [ 0 ] ,
200+ path ,
208201 config . preprocessorConfig ?. normMean || [ ] ,
209202 config . preprocessorConfig ?. normStd || [ ] ,
210- config . postprocessorConfig . applyNMS ?? true
203+ config . postprocessorConfig ?. applyNMS ?? true ,
204+ buildLabelNames ( config . labelMap )
211205 ) ;
212206
213207 return new InstanceSegmentationModule < L > (
@@ -252,20 +246,16 @@ export class InstanceSegmentationModule<
252246 ) ;
253247 }
254248
255- // Extract options with defaults from config
256249 const confidenceThreshold =
257250 options ?. confidenceThreshold ??
258- this . modelConfig . postprocessorConfig . defaultConfidenceThreshold ??
259- 0.55 ;
251+ this . modelConfig . defaultConfidenceThreshold ??
252+ 0.5 ;
260253 const iouThreshold =
261- options ?. iouThreshold ??
262- this . modelConfig . postprocessorConfig . defaultIouThreshold ??
263- 0.55 ;
254+ options ?. iouThreshold ?? this . modelConfig . defaultIouThreshold ?? 0.5 ;
264255 const maxInstances = options ?. maxInstances ?? 100 ;
265256 const returnMaskAtOriginalResolution =
266257 options ?. returnMaskAtOriginalResolution ?? true ;
267258
268- // Get inputSize from options or use default
269259 const inputSize = options ?. inputSize ?? this . modelConfig . defaultInputSize ;
270260
271261 if ( inputSize === undefined ) {
@@ -275,7 +265,6 @@ export class InstanceSegmentationModule<
275265 ) ;
276266 }
277267
278- // Validate inputSize against available sizes
279268 if (
280269 this . modelConfig . availableInputSizes &&
281270 ! this . modelConfig . availableInputSizes . includes (
@@ -288,7 +277,6 @@ export class InstanceSegmentationModule<
288277 ) ;
289278 }
290279
291- // Convert classesOfInterest labels to indices
292280 const classIndices = options ?. classesOfInterest
293281 ? options . classesOfInterest . map ( ( label ) => {
294282 const labelStr = String ( label ) ;
@@ -297,39 +285,14 @@ export class InstanceSegmentationModule<
297285 } )
298286 : [ ] ;
299287
300- // Measure inference time
301- const startTime = performance . now ( ) ;
302- const nativeResult = await this . nativeModule . generate (
288+ return await this . nativeModule . generate (
303289 imageSource ,
304290 confidenceThreshold ,
305291 iouThreshold ,
306292 maxInstances ,
307293 classIndices ,
308294 returnMaskAtOriginalResolution ,
309- inputSize // Pass inputSize as number instead of methodName as string
295+ inputSize
310296 ) ;
311- const endTime = performance . now ( ) ;
312- const inferenceTime = endTime - startTime ;
313-
314- console . log (
315- `[Instance Segmentation] Inference completed in ${ inferenceTime . toFixed ( 2 ) } ms | Input size: ${ inputSize } x${ inputSize } | Detected: ${ nativeResult . length } instances`
316- ) ;
317-
318- // Convert label indices back to label names
319- // YOLO outputs 0-indexed class IDs, but COCO labels are 1-indexed, so add 1
320- const reverseLabelMap = Object . entries (
321- this . labelMap as Record < string , number >
322- ) . reduce (
323- ( acc , [ key , value ] ) => {
324- acc [ value as number ] = key ;
325- return acc ;
326- } ,
327- { } as Record < number , string >
328- ) ;
329-
330- return nativeResult . map ( ( instance : any ) => ( {
331- ...instance ,
332- label : reverseLabelMap [ instance . label + 1 ] || `UNKNOWN_${ instance . label } ` ,
333- } ) ) as SegmentedInstance < ResolveLabels < T > > [ ] ;
334297 }
335298}
0 commit comments