Skip to content

Commit f9d9449

Browse files
committed
Update functions to throw errors if adapters not initialized
1 parent 1407899 commit f9d9449

File tree

8 files changed

+153
-97
lines changed

8 files changed

+153
-97
lines changed

packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { ResourceSource } from '../../types/common';
33
import { BaseModule } from '../BaseModule';
44
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
55
import { RnExecutorchError } from '../../errors/errorUtils';
6+
import { Logger } from '../../common/Logger';
67

78
/**
89
* Module for image classification tasks.
@@ -21,17 +22,24 @@ export class ClassificationModule extends BaseModule {
2122
model: { modelSource: ResourceSource },
2223
onDownloadProgressCallback: (progress: number) => void = () => {}
2324
): Promise<void> {
24-
const paths = await ResourceFetcher.fetch(
25-
onDownloadProgressCallback,
26-
model.modelSource
27-
);
28-
if (paths === null || paths.length < 1) {
29-
throw new RnExecutorchError(
30-
RnExecutorchErrorCode.DownloadInterrupted,
31-
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
25+
try {
26+
const paths = await ResourceFetcher.fetch(
27+
onDownloadProgressCallback,
28+
model.modelSource
3229
);
30+
31+
if (paths === null || paths.length < 1) {
32+
throw new RnExecutorchError(
33+
RnExecutorchErrorCode.DownloadInterrupted,
34+
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
35+
);
36+
}
37+
38+
this.nativeModule = global.loadClassification(paths[0] || '');
39+
} catch (error) {
40+
Logger.error('Load failed:', error);
41+
throw error;
3342
}
34-
this.nativeModule = global.loadClassification(paths[0] || '');
3543
}
3644

3745
/**

packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { ResourceSource } from '../../types/common';
33
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
44
import { RnExecutorchError } from '../../errors/errorUtils';
55
import { BaseModule } from '../BaseModule';
6+
import { Logger } from '../../common/Logger';
67

78
/**
89
* Module for generating image embeddings from input images.
@@ -20,17 +21,24 @@ export class ImageEmbeddingsModule extends BaseModule {
2021
model: { modelSource: ResourceSource },
2122
onDownloadProgressCallback: (progress: number) => void = () => {}
2223
): Promise<void> {
23-
const paths = await ResourceFetcher.fetch(
24-
onDownloadProgressCallback,
25-
model.modelSource
26-
);
27-
if (paths === null || paths.length < 1) {
28-
throw new RnExecutorchError(
29-
RnExecutorchErrorCode.DownloadInterrupted,
30-
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
24+
try {
25+
const paths = await ResourceFetcher.fetch(
26+
onDownloadProgressCallback,
27+
model.modelSource
3128
);
29+
30+
if (paths === null || paths.length < 1) {
31+
throw new RnExecutorchError(
32+
RnExecutorchErrorCode.DownloadInterrupted,
33+
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
34+
);
35+
}
36+
37+
this.nativeModule = global.loadClassification(paths[0] || '');
38+
} catch (error) {
39+
Logger.error('Load failed:', error);
40+
throw error;
3241
}
33-
this.nativeModule = global.loadImageEmbeddings(paths[0] || '');
3442
}
3543

3644
/**

packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { DeeplabLabel } from '../../types/imageSegmentation';
44
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
55
import { RnExecutorchError } from '../../errors/errorUtils';
66
import { BaseModule } from '../BaseModule';
7+
import { Logger } from '../../common/Logger';
78

89
/**
910
* Module for image segmentation tasks.
@@ -22,17 +23,24 @@ export class ImageSegmentationModule extends BaseModule {
2223
model: { modelSource: ResourceSource },
2324
onDownloadProgressCallback: (progress: number) => void = () => {}
2425
): Promise<void> {
25-
const paths = await ResourceFetcher.fetch(
26-
onDownloadProgressCallback,
27-
model.modelSource
28-
);
29-
if (paths === null || paths.length < 1) {
30-
throw new RnExecutorchError(
31-
RnExecutorchErrorCode.DownloadInterrupted,
32-
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
26+
try {
27+
const paths = await ResourceFetcher.fetch(
28+
onDownloadProgressCallback,
29+
model.modelSource
3330
);
31+
32+
if (paths === null || paths.length < 1) {
33+
throw new RnExecutorchError(
34+
RnExecutorchErrorCode.DownloadInterrupted,
35+
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
36+
);
37+
}
38+
39+
this.nativeModule = global.loadClassification(paths[0] || '');
40+
} catch (error) {
41+
Logger.error('Load failed:', error);
42+
throw error;
3443
}
35-
this.nativeModule = global.loadImageSegmentation(paths[0] || '');
3644
}
3745

3846
/**

packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,17 @@ export class OCRModule {
3030
},
3131
onDownloadProgressCallback: (progress: number) => void = () => {}
3232
) {
33-
await this.controller.load(
34-
model.detectorSource,
35-
model.recognizerSource,
36-
model.language,
37-
onDownloadProgressCallback
38-
);
33+
try {
34+
await this.controller.load(
35+
model.detectorSource,
36+
model.recognizerSource,
37+
model.language,
38+
onDownloadProgressCallback
39+
);
40+
} catch (error) {
41+
console.error('Load Failed:', error);
42+
throw error;
43+
}
3944
}
4045

4146
/**

packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { Detection } from '../../types/objectDetection';
44
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
55
import { RnExecutorchError } from '../../errors/errorUtils';
66
import { BaseModule } from '../BaseModule';
7+
import { Logger } from '../../common/Logger';
78

89
/**
910
* Module for object detection tasks.
@@ -22,17 +23,24 @@ export class ObjectDetectionModule extends BaseModule {
2223
model: { modelSource: ResourceSource },
2324
onDownloadProgressCallback: (progress: number) => void = () => {}
2425
): Promise<void> {
25-
const paths = await ResourceFetcher.fetch(
26-
onDownloadProgressCallback,
27-
model.modelSource
28-
);
29-
if (paths === null || paths.length < 1) {
30-
throw new RnExecutorchError(
31-
RnExecutorchErrorCode.DownloadInterrupted,
32-
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
26+
try {
27+
const paths = await ResourceFetcher.fetch(
28+
onDownloadProgressCallback,
29+
model.modelSource
3330
);
31+
32+
if (paths === null || paths.length < 1) {
33+
throw new RnExecutorchError(
34+
RnExecutorchErrorCode.DownloadInterrupted,
35+
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
36+
);
37+
}
38+
39+
this.nativeModule = global.loadClassification(paths[0] || '');
40+
} catch (error) {
41+
Logger.error('Load failed:', error);
42+
throw error;
3443
}
35-
this.nativeModule = global.loadObjectDetection(paths[0] || '');
3644
}
3745

3846
/**

packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { ResourceSource } from '../../types/common';
33
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
44
import { RnExecutorchError } from '../../errors/errorUtils';
55
import { BaseModule } from '../BaseModule';
6+
import { Logger } from '../../common/Logger';
67

78
/**
89
* Module for style transfer tasks.
@@ -21,17 +22,24 @@ export class StyleTransferModule extends BaseModule {
2122
model: { modelSource: ResourceSource },
2223
onDownloadProgressCallback: (progress: number) => void = () => {}
2324
): Promise<void> {
24-
const paths = await ResourceFetcher.fetch(
25-
onDownloadProgressCallback,
26-
model.modelSource
27-
);
28-
if (paths === null || paths.length < 1) {
29-
throw new RnExecutorchError(
30-
RnExecutorchErrorCode.DownloadInterrupted,
31-
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
25+
try {
26+
const paths = await ResourceFetcher.fetch(
27+
onDownloadProgressCallback,
28+
model.modelSource
3229
);
30+
31+
if (paths === null || paths.length < 1) {
32+
throw new RnExecutorchError(
33+
RnExecutorchErrorCode.DownloadInterrupted,
34+
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
35+
);
36+
}
37+
38+
this.nativeModule = global.loadClassification(paths[0] || '');
39+
} catch (error) {
40+
Logger.error('Load failed:', error);
41+
throw error;
3342
}
34-
this.nativeModule = global.loadStyleTransfer(paths[0] || '');
3543
}
3644

3745
/**

packages/react-native-executorch/src/modules/computer_vision/TextToImageModule.ts

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -42,49 +42,54 @@ export class TextToImageModule extends BaseModule {
4242
},
4343
onDownloadProgressCallback: (progress: number) => void = () => {}
4444
): Promise<void> {
45-
const results = await ResourceFetcher.fetch(
46-
onDownloadProgressCallback,
47-
model.tokenizerSource,
48-
model.schedulerSource,
49-
model.encoderSource,
50-
model.unetSource,
51-
model.decoderSource
52-
);
53-
if (!results) {
54-
throw new RnExecutorchError(
55-
RnExecutorchErrorCode.DownloadInterrupted,
56-
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
45+
try {
46+
const results = await ResourceFetcher.fetch(
47+
onDownloadProgressCallback,
48+
model.tokenizerSource,
49+
model.schedulerSource,
50+
model.encoderSource,
51+
model.unetSource,
52+
model.decoderSource
5753
);
58-
}
59-
const [tokenizerPath, schedulerPath, encoderPath, unetPath, decoderPath] =
60-
results;
54+
if (!results) {
55+
throw new RnExecutorchError(
56+
RnExecutorchErrorCode.DownloadInterrupted,
57+
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
58+
);
59+
}
60+
const [tokenizerPath, schedulerPath, encoderPath, unetPath, decoderPath] =
61+
results;
6162

62-
if (
63-
!tokenizerPath ||
64-
!schedulerPath ||
65-
!encoderPath ||
66-
!unetPath ||
67-
!decoderPath
68-
) {
69-
throw new RnExecutorchError(
70-
RnExecutorchErrorCode.DownloadInterrupted,
71-
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
72-
);
73-
}
63+
if (
64+
!tokenizerPath ||
65+
!schedulerPath ||
66+
!encoderPath ||
67+
!unetPath ||
68+
!decoderPath
69+
) {
70+
throw new RnExecutorchError(
71+
RnExecutorchErrorCode.DownloadInterrupted,
72+
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
73+
);
74+
}
7475

75-
const response = await fetch('file://' + schedulerPath);
76-
const schedulerConfig = await response.json();
76+
const response = await fetch('file://' + schedulerPath);
77+
const schedulerConfig = await response.json();
7778

78-
this.nativeModule = global.loadTextToImage(
79-
tokenizerPath,
80-
encoderPath,
81-
unetPath,
82-
decoderPath,
83-
schedulerConfig.beta_start,
84-
schedulerConfig.beta_end,
85-
schedulerConfig.num_train_timesteps,
86-
schedulerConfig.steps_offset
87-
);
79+
this.nativeModule = global.loadTextToImage(
80+
tokenizerPath,
81+
encoderPath,
82+
unetPath,
83+
decoderPath,
84+
schedulerConfig.beta_start,
85+
schedulerConfig.beta_end,
86+
schedulerConfig.num_train_timesteps,
87+
schedulerConfig.steps_offset
88+
);
89+
} catch (error) {
90+
console.error('Load Failed:', error);
91+
throw error;
92+
}
8893
}
8994

9095
/**

packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { Logger } from '../../common/Logger';
12
import { VerticalOCRController } from '../../controllers/VerticalOCRController';
23
import { ResourceSource } from '../../types/common';
34
import { OCRDetection, OCRLanguage } from '../../types/ocr';
@@ -32,13 +33,18 @@ export class VerticalOCRModule {
3233
independentCharacters: boolean,
3334
onDownloadProgressCallback: (progress: number) => void = () => {}
3435
) {
35-
await this.controller.load(
36-
model.detectorSource,
37-
model.recognizerSource,
38-
model.language,
39-
independentCharacters,
40-
onDownloadProgressCallback
41-
);
36+
try {
37+
await this.controller.load(
38+
model.detectorSource,
39+
model.recognizerSource,
40+
model.language,
41+
independentCharacters,
42+
onDownloadProgressCallback
43+
);
44+
} catch (error) {
45+
Logger.error('Load failed:', error);
46+
throw error;
47+
}
4248
}
4349

4450
/**

0 commit comments

Comments
 (0)