diff --git a/apps/computer-vision/app/object_detection/index.tsx b/apps/computer-vision/app/object_detection/index.tsx index 3ce52c409f..bbc609ab24 100644 --- a/apps/computer-vision/app/object_detection/index.tsx +++ b/apps/computer-vision/app/object_detection/index.tsx @@ -43,7 +43,6 @@ export default function ObjectDetectionScreen() { if (imageUri) { try { const output = await ssdLite.forward(imageUri); - console.log(output); setResults(output); } catch (e) { console.error(e); diff --git a/apps/computer-vision/app/ocr/index.tsx b/apps/computer-vision/app/ocr/index.tsx index b2ba8d04dc..c56309e8a0 100644 --- a/apps/computer-vision/app/ocr/index.tsx +++ b/apps/computer-vision/app/ocr/index.tsx @@ -38,7 +38,6 @@ export default function OCRScreen() { try { const output = await model.forward(imageUri); setResults(output); - console.log(output); } catch (e) { console.error(e); } @@ -78,8 +77,8 @@ export default function OCRScreen() { Results - {results.map(({ text, score }) => ( - + {results.map(({ text, score }, index) => ( + {text} {score.toFixed(3)} diff --git a/apps/computer-vision/app/ocr_vertical/index.tsx b/apps/computer-vision/app/ocr_vertical/index.tsx index 040c709c63..28e73eac5c 100644 --- a/apps/computer-vision/app/ocr_vertical/index.tsx +++ b/apps/computer-vision/app/ocr_vertical/index.tsx @@ -40,7 +40,6 @@ export default function VerticalOCRScree() { try { const output = await model.forward(imageUri); setResults(output); - console.log(output); } catch (e) { console.error(e); } @@ -80,8 +79,8 @@ export default function VerticalOCRScree() { Results - {results.map(({ text, score }) => ( - + {results.map(({ text, score }, index) => ( + {text} {score.toFixed(3)} diff --git a/apps/computer-vision/ios/Podfile.lock b/apps/computer-vision/ios/Podfile.lock index 730d430f63..a90e6e1e19 100644 --- a/apps/computer-vision/ios/Podfile.lock +++ b/apps/computer-vision/ios/Podfile.lock @@ -2454,7 +2454,7 @@ SPEC CHECKSUMS: React-logger: 
8edfcedc100544791cd82692ca5a574240a16219 React-Mapbuffer: c3f4b608e4a59dd2f6a416ef4d47a14400194468 React-microtasksnativemodule: 054f34e9b82f02bd40f09cebd4083828b5b2beb6 - react-native-executorch: 98a2d5c0fc2290d473db87f2d6f3bf9dc7b77ab1 + react-native-executorch: d06ae11e5411f0cb798316c4e69cf7d8678da297 react-native-image-picker: 8a3f16000e794f5381a7fe47bb48fd8d06741e47 react-native-safe-area-context: 562163222d999b79a51577eda2ea8ad2c32b4d06 react-native-skia: b6cb66e99a953dae6880348c92cfb20a76d90b4f diff --git a/docs/docs/02-hooks/02-computer-vision/useOCR.md b/docs/docs/02-hooks/02-computer-vision/useOCR.md index ad25998d76..1384bed7a9 100644 --- a/docs/docs/02-hooks/02-computer-vision/useOCR.md +++ b/docs/docs/02-hooks/02-computer-vision/useOCR.md @@ -301,19 +301,34 @@ You need to make sure the recognizer models you pass in `recognizerSources` matc | Model | Android (XNNPACK) [MB] | iOS (XNNPACK) [MB] | | -------------------------------------------------------------------------------------------- | :--------------------: | :----------------: | -| Detector (CRAFT_800) + Recognizer (CRNN_512) + Recognizer (CRNN_256) + Recognizer (CRNN_128) | 2100 | 1782 | +| Detector (CRAFT_800) + Recognizer (CRNN_512) + Recognizer (CRNN_256) + Recognizer (CRNN_128) | 1600 | 1700 | ### Inference time +**Image Used for Benchmarking:** + +| ![Alt text](../../../static/img/harvard.png) | ![Alt text](../../../static/img/harvard-boxes.png) | +| -------------------------------------------- | -------------------------------------------------- | +| Original Image | Image with detected Text Boxes | + :::warning warning Times presented in the tables are measured as consecutive runs of the model. Initial run times may be up to 2x longer due to model loading and initialization. 
::: -| Model | iPhone 16 Pro (XNNPACK) [ms] | iPhone 14 Pro Max (XNNPACK) [ms] | iPhone SE 3 (XNNPACK) [ms] | Samsung Galaxy S24 (XNNPACK) [ms] | Samsung Galaxy S21 (XNNPACK) [ms] | -| --------------------- | :--------------------------: | :------------------------------: | :------------------------: | :-------------------------------: | :-------------------------------: | -| Detector (CRAFT_800) | 2099 | 2227 | ❌ | 2245 | 7108 | -| Recognizer (CRNN_512) | 70 | 252 | ❌ | 54 | 151 | -| Recognizer (CRNN_256) | 39 | 123 | ❌ | 24 | 78 | -| Recognizer (CRNN_128) | 17 | 83 | ❌ | 14 | 39 | +**Time measurements:** + +| Metric | iPhone 14 Pro Max
[ms] | iPhone 16 Pro
[ms] | iPhone SE 3 | Samsung Galaxy S24
[ms] | OnePlus 12
[ms] | +| ------------------------- | ----------------------------- | ------------------------- | ----------- | ------------------------------ | ---------------------- | +| **Total Inference Time** | 4330 | 2537 | ❌ | 6648 | 5993 | +| **Detector (CRAFT_800)** | 1945 | 1809 | ❌ | 2080 | 1961 | +| **Recognizer (CRNN_512)** | | | | | | +| ├─ Average Time | 273 | 76 | ❌ | 289 | 252 | +| ├─ Total Time (3 runs) | 820 | 229 | ❌ | 867 | 756 | +| **Recognizer (CRNN_256)** | | | | | | +| ├─ Average Time | 137 | 39 | ❌ | 260 | 229 | +| ├─ Total Time (7 runs) | 958 | 271 | ❌ | 1818 | 1601 | +| **Recognizer (CRNN_128)** | | | | | | +| ├─ Average Time | 68 | 18 | ❌ | 239 | 214 | +| ├─ Total Time (7 runs) | 478 | 124 | ❌ | 1673 | 1498 | ❌ - Insufficient RAM. diff --git a/docs/docs/02-hooks/02-computer-vision/useVerticalOCR.md b/docs/docs/02-hooks/02-computer-vision/useVerticalOCR.md index b449d9f07c..b94eef9db0 100644 --- a/docs/docs/02-hooks/02-computer-vision/useVerticalOCR.md +++ b/docs/docs/02-hooks/02-computer-vision/useVerticalOCR.md @@ -316,20 +316,35 @@ You need to make sure the recognizer models you pass in `recognizerSources` matc | Model | Android (XNNPACK) [MB] | iOS (XNNPACK) [MB] | | -------------------------------------------------------------------- | :--------------------: | :----------------: | -| Detector (CRAFT_1280) + Detector (CRAFT_320) + Recognizer (CRNN_512) | 2770 | 3720 | -| Detector(CRAFT_1280) + Detector(CRAFT_320) + Recognizer (CRNN_64) | 1770 | 2740 | +| Detector (CRAFT_1280) + Detector (CRAFT_320) + Recognizer (CRNN_512) | 2172 | 2214 | +| Detector(CRAFT_1280) + Detector(CRAFT_320) + Recognizer (CRNN_64) | 1774 | 1705 | ### Inference time +**Image Used for Benchmarking:** + +| ![Alt text](../../../static/img/sales-vertical.jpeg) | ![Alt text](../../../static/img/sales-vertical-boxes.png) | +| ---------------------------------------------------- | --------------------------------------------------------- | +| Original Image | Image with detected 
Text Boxes | + :::warning warning Times presented in the tables are measured as consecutive runs of the model. Initial run times may be up to 2x longer due to model loading and initialization. ::: -| Model | iPhone 16 Pro (XNNPACK) [ms] | iPhone 14 Pro Max (XNNPACK) [ms] | iPhone SE 3 (XNNPACK) [ms] | Samsung Galaxy S24 (XNNPACK) [ms] | Samsung Galaxy S21 (XNNPACK) [ms] | -| --------------------- | :--------------------------: | :------------------------------: | :------------------------: | :-------------------------------: | :-------------------------------: | -| Detector (CRAFT_1280) | 5457 | 5833 | ❌ | 6296 | 14053 | -| Detector (CRAFT_320) | 1351 | 1460 | ❌ | 1485 | 3101 | -| Recognizer (CRNN_512) | 39 | 123 | ❌ | 24 | 78 | -| Recognizer (CRNN_64) | 10 | 33 | ❌ | 7 | 18 | +**Time measurements:** + +| Metric | iPhone 14 Pro Max
[ms] | iPhone 16 Pro
[ms] | iPhone SE 3 | Samsung Galaxy S24
[ms] | OnePlus 12
[ms] | +| -------------------------------------------------------------------------- | ----------------------------- | ------------------------- | ----------- | ------------------------------ | ---------------------- | +| **Total Inference Time** | 9350 / 9620 | 8572 / 8621 | ❌ | 13737 / 10570 | 13436 / 9848 | +| **Detector (CRAFT_1280)** | 4895 | 4756 | ❌ | 5574 | 5016 | +| **Detector (CRAFT_320)** | | | | | | +| ├─ Average Time | 1247 | 1206 | ❌ | 1350 | 1356 | +| ├─ Total Time (3 runs) | 3741 | 3617 | ❌ | 4050 | 4069 | +| **Recognizer (CRNN_64)**
(_With Flag `independentChars == true`_) | | | | | | +| ├─ Average Time | 31 | 9 | ❌ | 195 | 207 | +| ├─ Total Time (21 runs) | 649 | 191 | ❌ | 4092 | 4339 | +| **Recognizer (CRNN_512)**
(_With Flag `independentChars == false`_) | | | | | | +| ├─ Average Time | 306 | 80 | ❌ | 308 | 250 | +| ├─ Total Time (3 runs) | 919 | 240 | ❌ | 925 | 751 | ❌ - Insufficient RAM. diff --git a/docs/docs/03-typescript-api/02-computer-vision/OCRModule.md b/docs/docs/03-typescript-api/02-computer-vision/OCRModule.md index f709ffe1a4..43a812005f 100644 --- a/docs/docs/03-typescript-api/02-computer-vision/OCRModule.md +++ b/docs/docs/03-typescript-api/02-computer-vision/OCRModule.md @@ -22,11 +22,11 @@ const detections = await ocrModule.forward(imageUri); ### Methods -| Method | Type | Description | -| -------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `load` | `(model: { detectorSource: ResourceSource; recognizerLarge: ResourceSource; recognizerMedium: ResourceSource; recognizerSmall: ResourceSource; language: OCRLanguage }, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model, where `detectorSource` is a string that specifies the location of the detector binary, `recognizerLarge` is a string that specifies the location of the recognizer binary file which accepts input images with a width of 512 
pixels, `recognizerMedium` is a string that specifies the location of the recognizer binary file which accepts input images with a width of 256 pixels, `recognizerSmall` is a string that specifies the location of the recognizer binary file which accepts input images with a width of 128 pixels, and `language` is a parameter that specifies the language of the text to be recognized by the OCR. | -| `forward` | `(input: string): Promise` | Executes the model's forward pass, where `input` can be a fetchable resource or a Base64-encoded string. | -| `onDownloadProgress` | `(callback: (downloadProgress: number) => void): any` | Subscribe to the download progress event. | +| Method | Type | Description | +| --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `load` | `(model: { detectorSource: ResourceSource; recognizerLarge: ResourceSource; recognizerMedium: ResourceSource; recognizerSmall: ResourceSource; language: OCRLanguage }, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model, where `detectorSource` is a string that specifies the location of the detector binary, `recognizerLarge` is a string that specifies 
the location of the recognizer binary file which accepts input images with a width of 512 pixels, `recognizerMedium` is a string that specifies the location of the recognizer binary file which accepts input images with a width of 256 pixels, `recognizerSmall` is a string that specifies the location of the recognizer binary file which accepts input images with a width of 128 pixels, and `language` is a parameter that specifies the language of the text to be recognized by the OCR. | +| `forward` | `(input: string): Promise` | Executes the model's forward pass, where `input` can be a fetchable resource or a Base64-encoded string. | +| `delete` | `(): void` | Release the memory held by the module. Calling `forward` afterwards is invalid. Note that you cannot delete model while it's generating. |
Type definitions diff --git a/docs/docs/03-typescript-api/02-computer-vision/VerticalOCRModule.md b/docs/docs/03-typescript-api/02-computer-vision/VerticalOCRModule.md index bf3b56c7e5..27b4564adb 100644 --- a/docs/docs/03-typescript-api/02-computer-vision/VerticalOCRModule.md +++ b/docs/docs/03-typescript-api/02-computer-vision/VerticalOCRModule.md @@ -26,11 +26,11 @@ const detections = await verticalOCRModule.forward(imageUri); ### Methods -| Method | Type | Description | -| -------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `load` | `(model: { detectorLarge: ResourceSource; detectorNarrow: ResourceSource; recognizerLarge: ResourceSource; recognizerSmall: ResourceSource; language: OCRLanguage }, independentCharacters: boolean, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model, where `detectorLarge` is a string that specifies the location of the recognizer binary file which accepts input images with a width of 1280 pixels, `detectorNarrow` is a string that specifies the location of the detector binary file 
which accepts input images with a width of 320 pixels, `recognizerLarge` is a string that specifies the location of the recognizer binary file which accepts input images with a width of 512 pixels, `recognizerSmall` is a string that specifies the location of the recognizer binary file which accepts input images with a width of 64 pixels, and `language` is a parameter that specifies the language of the text to be recognized by the OCR. | -| `forward` | `(input: string): Promise` | Executes the model's forward pass, where `input` can be a fetchable resource or a Base64-encoded string. | -| `onDownloadProgress` | `(callback: (downloadProgress: number) => void): any` | Subscribe to the download progress event. | +| Method | Type | Description | +| --------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `load` | `(model: { detectorLarge: ResourceSource; detectorNarrow: ResourceSource; recognizerLarge: ResourceSource; recognizerSmall: ResourceSource; language: OCRLanguage }, independentCharacters: boolean, onDownloadProgressCallback?: (progress: number) => void): 
Promise` | Loads the model, where `detectorLarge` is a string that specifies the location of the recognizer binary file which accepts input images with a width of 1280 pixels, `detectorNarrow` is a string that specifies the location of the detector binary file which accepts input images with a width of 320 pixels, `recognizerLarge` is a string that specifies the location of the recognizer binary file which accepts input images with a width of 512 pixels, `recognizerSmall` is a string that specifies the location of the recognizer binary file which accepts input images with a width of 64 pixels, and `language` is a parameter that specifies the language of the text to be recognized by the OCR. | +| `forward` | `(input: string): Promise` | Executes the model's forward pass, where `input` can be a fetchable resource or a Base64-encoded string. | +| `delete` | `(): void` | Release the memory held by the module. Calling `forward` afterwards is invalid. Note that you cannot delete model while it's generating. |
Type definitions diff --git a/docs/static/img/harvard-boxes.png b/docs/static/img/harvard-boxes.png new file mode 100644 index 0000000000..e53562131a Binary files /dev/null and b/docs/static/img/harvard-boxes.png differ diff --git a/docs/static/img/harvard.png b/docs/static/img/harvard.png new file mode 100644 index 0000000000..e6fb37ff98 Binary files /dev/null and b/docs/static/img/harvard.png differ diff --git a/docs/static/img/sales-vertical-boxes.png b/docs/static/img/sales-vertical-boxes.png new file mode 100644 index 0000000000..26278cc460 Binary files /dev/null and b/docs/static/img/sales-vertical-boxes.png differ diff --git a/docs/static/img/sales-vertical.jpeg b/docs/static/img/sales-vertical.jpeg new file mode 100644 index 0000000000..8d017b0c57 Binary files /dev/null and b/docs/static/img/sales-vertical.jpeg differ diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt deleted file mode 100644 index b679d95dfd..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt +++ /dev/null @@ -1,90 +0,0 @@ -package com.swmansion.rnexecutorch - -import android.util.Log -import com.facebook.react.bridge.Promise -import com.facebook.react.bridge.ReactApplicationContext -import com.swmansion.rnexecutorch.models.ocr.Detector -import com.swmansion.rnexecutorch.models.ocr.RecognitionHandler -import com.swmansion.rnexecutorch.models.ocr.utils.Constants -import com.swmansion.rnexecutorch.utils.ETError -import com.swmansion.rnexecutorch.utils.ImageProcessor -import org.opencv.android.OpenCVLoader -import org.opencv.imgproc.Imgproc - -class OCR( - reactContext: ReactApplicationContext, -) : NativeOCRSpec(reactContext) { - private lateinit var detector: Detector - private lateinit var recognitionHandler: RecognitionHandler - - companion object { - const val NAME = "OCR" - } - - init { - if 
(!OpenCVLoader.initLocal()) { - Log.d("rn_executorch", "OpenCV not loaded") - } else { - Log.d("rn_executorch", "OpenCV loaded") - } - } - - override fun loadModule( - detectorSource: String, - recognizerSourceLarge: String, - recognizerSourceMedium: String, - recognizerSourceSmall: String, - symbols: String, - promise: Promise, - ) { - try { - detector = Detector(reactApplicationContext) - detector.loadModel(detectorSource) - - recognitionHandler = - RecognitionHandler( - symbols, - reactApplicationContext, - ) - - recognitionHandler.loadRecognizers( - recognizerSourceLarge, - recognizerSourceMedium, - recognizerSourceSmall, - ) { _, errorRecognizer -> - if (errorRecognizer != null) { - throw Error(errorRecognizer.message!!) - } - - promise.resolve(0) - } - } catch (e: Exception) { - promise.reject(e.message!!, ETError.InvalidModelSource.toString()) - } - } - - override fun forward( - input: String, - promise: Promise, - ) { - try { - val inputImage = ImageProcessor.readImage(input) - val bBoxesList = detector.runModel(inputImage) - val detectorSize = detector.getModelImageSize() - Imgproc.cvtColor(inputImage, inputImage, Imgproc.COLOR_BGR2GRAY) - val result = - recognitionHandler.recognize( - bBoxesList, - inputImage, - (detectorSize.width * Constants.RECOGNIZER_RATIO).toInt(), - (detectorSize.height * Constants.RECOGNIZER_RATIO).toInt(), - ) - promise.resolve(result) - } catch (e: Exception) { - Log.d("rn_executorch", "Error running model: ${e.message}") - promise.reject(e.message!!, e.message) - } - } - - override fun getName(): String = NAME -} diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt index c4dfc6c5c1..0b15e216a5 100644 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt +++ 
b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt @@ -14,11 +14,7 @@ class RnExecutorchPackage : TurboReactPackage() { name: String, reactContext: ReactApplicationContext, ): NativeModule? = - if (name == OCR.NAME) { - OCR(reactContext) - } else if (name == VerticalOCR.NAME) { - VerticalOCR(reactContext) - } else if (name == ETInstaller.NAME) { + if (name == ETInstaller.NAME) { ETInstaller(reactContext) } else { null @@ -27,28 +23,6 @@ class RnExecutorchPackage : TurboReactPackage() { override fun getReactModuleInfoProvider(): ReactModuleInfoProvider = ReactModuleInfoProvider { val moduleInfos: MutableMap = HashMap() - moduleInfos[OCR.NAME] = - ReactModuleInfo( - OCR.NAME, - OCR.NAME, - false, // canOverrideExistingModule - false, // needsEagerInit - true, // hasConstants - false, // isCxxModule - true, - ) - - moduleInfos[VerticalOCR.NAME] = - ReactModuleInfo( - VerticalOCR.NAME, - VerticalOCR.NAME, - false, // canOverrideExistingModule - false, // needsEagerInit - true, // hasConstants - false, // isCxxModule - true, - ) - moduleInfos[ETInstaller.NAME] = ReactModuleInfo( ETInstaller.NAME, diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt deleted file mode 100644 index 3e36b3edba..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt +++ /dev/null @@ -1,179 +0,0 @@ -package com.swmansion.rnexecutorch - -import android.util.Log -import com.facebook.react.bridge.Arguments -import com.facebook.react.bridge.Promise -import com.facebook.react.bridge.ReactApplicationContext -import com.swmansion.rnexecutorch.models.ocr.Recognizer -import com.swmansion.rnexecutorch.models.ocr.VerticalDetector -import com.swmansion.rnexecutorch.models.ocr.utils.CTCLabelConverter -import 
com.swmansion.rnexecutorch.models.ocr.utils.Constants -import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils -import com.swmansion.rnexecutorch.utils.ETError -import com.swmansion.rnexecutorch.utils.ImageProcessor -import org.opencv.android.OpenCVLoader -import org.opencv.core.Core -import org.opencv.core.Mat - -class VerticalOCR( - reactContext: ReactApplicationContext, -) : NativeVerticalOCRSpec(reactContext) { - private lateinit var detectorLarge: VerticalDetector - private lateinit var detectorNarrow: VerticalDetector - private lateinit var recognizer: Recognizer - private lateinit var converter: CTCLabelConverter - private var independentCharacters = true - - companion object { - const val NAME = "VerticalOCR" - } - - init { - if (!OpenCVLoader.initLocal()) { - Log.d("rn_executorch", "OpenCV not loaded") - } else { - Log.d("rn_executorch", "OpenCV loaded") - } - } - - override fun loadModule( - detectorLargeSource: String, - detectorNarrowSource: String, - recognizerSource: String, - symbols: String, - independentCharacters: Boolean, - promise: Promise, - ) { - try { - this.independentCharacters = independentCharacters - detectorLarge = VerticalDetector(false, reactApplicationContext) - detectorLarge.loadModel(detectorLargeSource) - detectorNarrow = VerticalDetector(true, reactApplicationContext) - detectorNarrow.loadModel(detectorNarrowSource) - recognizer = Recognizer(reactApplicationContext) - recognizer.loadModel(recognizerSource) - - converter = CTCLabelConverter(symbols) - - promise.resolve(0) - } catch (e: Exception) { - promise.reject(e.message!!, ETError.InvalidModelSource.toString()) - } - } - - override fun forward( - input: String, - promise: Promise, - ) { - try { - val inputImage = ImageProcessor.readImage(input) - val result = detectorLarge.runModel(inputImage) - val largeDetectorSize = detectorLarge.getModelImageSize() - val resizedImage = - ImageProcessor.resizeWithPadding( - inputImage, - largeDetectorSize.width.toInt(), - 
largeDetectorSize.height.toInt(), - ) - val predictions = Arguments.createArray() - for (box in result) { - val cords = box.bBox - val boxWidth = cords[2].x - cords[0].x - val boxHeight = cords[2].y - cords[0].y - - val boundingBox = RecognizerUtils.extractBoundingBox(cords) - val croppedImage = Mat(resizedImage, boundingBox) - - val paddings = - RecognizerUtils.calculateResizeRatioAndPaddings( - inputImage.width(), - inputImage.height(), - largeDetectorSize.width.toInt(), - largeDetectorSize.height.toInt(), - ) - - var text = "" - var confidenceScore = 0.0 - val boxResult = detectorNarrow.runModel(croppedImage) - val narrowDetectorSize = detectorNarrow.getModelImageSize() - - val croppedCharacters = mutableListOf() - - for (characterBox in boxResult) { - val boxCords = characterBox.bBox - val paddingsBox = - RecognizerUtils.calculateResizeRatioAndPaddings( - boxWidth.toInt(), - boxHeight.toInt(), - narrowDetectorSize.width.toInt(), - narrowDetectorSize.height.toInt(), - ) - - var croppedCharacter = - RecognizerUtils.cropImageWithBoundingBox( - inputImage, - boxCords, - cords, - paddingsBox, - paddings, - ) - - if (this.independentCharacters) { - croppedCharacter = RecognizerUtils.cropSingleCharacter(croppedCharacter) - croppedCharacter = RecognizerUtils.normalizeForRecognizer(croppedCharacter, 0.0, true) - val recognitionResult = recognizer.runModel(croppedCharacter) - val predIndex = recognitionResult.first - val decodedText = converter.decodeGreedy(predIndex, predIndex.size) - text += decodedText[0] - confidenceScore += recognitionResult.second - } else { - croppedCharacters.add(croppedCharacter) - } - } - - if (this.independentCharacters) { - confidenceScore /= boxResult.size - } else { - var mergedCharacters = Mat() - Core.hconcat(croppedCharacters, mergedCharacters) - mergedCharacters = - ImageProcessor.resizeWithPadding( - mergedCharacters, - Constants.LARGE_MODEL_WIDTH, - Constants.MODEL_HEIGHT, - ) - mergedCharacters = 
RecognizerUtils.normalizeForRecognizer(mergedCharacters, 0.0) - - val recognitionResult = recognizer.runModel(mergedCharacters) - val predIndex = recognitionResult.first - val decodedText = converter.decodeGreedy(predIndex, predIndex.size) - - text = decodedText[0] - confidenceScore = recognitionResult.second - } - - for (bBox in box.bBox) { - bBox.x = - (bBox.x - paddings["left"] as Int) * paddings["resizeRatio"] as Float - bBox.y = - (bBox.y - paddings["top"] as Int) * paddings["resizeRatio"] as Float - } - - val resMap = Arguments.createMap() - - resMap.putString("text", text) - resMap.putArray("bbox", box.toWritableArray()) - resMap.putDouble("score", confidenceScore) - - predictions.pushMap(resMap) - } - - promise.resolve(predictions) - } catch (e: Exception) { - Log.d("rn_executorch", "Error running model: ${e.message}") - promise.reject(e.message!!, e.message) - } - } - - override fun getName(): String = NAME -} diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt deleted file mode 100644 index 9e010e3472..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt +++ /dev/null @@ -1,54 +0,0 @@ -package com.swmansion.rnexecutorch.models - -import android.content.Context -import com.swmansion.rnexecutorch.utils.ETError -import org.pytorch.executorch.EValue -import org.pytorch.executorch.Module -import org.pytorch.executorch.Tensor - -abstract class BaseModel( - val context: Context, -) { - protected lateinit var module: Module - - fun loadModel(modelSource: String) { - module = Module.load(modelSource) - } - - protected fun forward(vararg inputs: EValue): Array { - try { - val result = module.forward(*inputs) - return result - } catch (e: IllegalArgumentException) { - // The error is thrown when transformation to Tensor fails - throw 
Error(ETError.InvalidArgument.code.toString()) - } catch (e: Exception) { - throw Error(e.message) - } - } - - protected fun forward( - inputs: Array, - shapes: Array, - ): Array = this.execute("forward", inputs, shapes) - - protected fun execute( - methodName: String, - inputs: Array, - shapes: Array, - ): Array { - // We want to convert each input to EValue, a data structure accepted by ExecuTorch's - // Module. The array below keeps track of that values. - try { - val executorchInputs = inputs.mapIndexed { index, _ -> EValue.from(Tensor.fromBlob(inputs[index], shapes[index])) } - val forwardResult = module.execute(methodName, *executorchInputs.toTypedArray()) - return forwardResult - } catch (e: IllegalArgumentException) { - throw Error(ETError.InvalidArgument.code.toString()) - } catch (e: Exception) { - throw Error(e.message) - } - } - - abstract fun runModel(input: Input): Output -} diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt deleted file mode 100644 index 2d17cf44d8..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt +++ /dev/null @@ -1,82 +0,0 @@ -package com.swmansion.rnexecutorch.models.ocr - -import com.facebook.react.bridge.ReactApplicationContext -import com.swmansion.rnexecutorch.models.BaseModel -import com.swmansion.rnexecutorch.models.ocr.utils.Constants -import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils -import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox -import com.swmansion.rnexecutorch.utils.ImageProcessor -import org.opencv.core.Mat -import org.opencv.core.Size -import org.pytorch.executorch.EValue - -class Detector( - reactApplicationContext: ReactApplicationContext, -) : BaseModel>(reactApplicationContext) { - private lateinit var originalSize: Size - - fun 
getModelImageSize(): Size { - val inputShape = module.getInputShape(0) - val width = inputShape[inputShape.lastIndex - 1] - val height = inputShape[inputShape.lastIndex] - - val modelImageSize = Size(height.toDouble(), width.toDouble()) - - return modelImageSize - } - - fun preprocess(input: Mat): EValue { - originalSize = Size(input.cols().toDouble(), input.rows().toDouble()) - val resizedImage = - ImageProcessor.resizeWithPadding( - input, - getModelImageSize().width.toInt(), - getModelImageSize().height.toInt(), - ) - - return ImageProcessor.matToEValue( - resizedImage, - module.getInputShape(0), - Constants.MEAN, - Constants.VARIANCE, - ) - } - - fun postprocess(output: Array): List { - val outputTensor = output[0].toTensor() - val outputArray = outputTensor.dataAsFloatArray - val modelImageSize = getModelImageSize() - - val (scoreText, scoreLink) = - DetectorUtils.interleavedArrayToMats( - outputArray, - Size(modelImageSize.width / 2, modelImageSize.height / 2), - ) - var bBoxesList = - DetectorUtils.getDetBoxesFromTextMap( - scoreText, - scoreLink, - Constants.TEXT_THRESHOLD, - Constants.LINK_THRESHOLD, - Constants.LOW_TEXT_THRESHOLD, - ) - - bBoxesList = - DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RECOGNIZER_RATIO * 2).toFloat()) - - bBoxesList = - DetectorUtils.groupTextBoxes( - bBoxesList, - Constants.CENTER_THRESHOLD, - Constants.DISTANCE_THRESHOLD, - Constants.HEIGHT_THRESHOLD, - Constants.MIN_SIDE_THRESHOLD, - Constants.MAX_SIDE_THRESHOLD, - Constants.MAX_WIDTH, - ) - - return bBoxesList.toList() - } - - override fun runModel(input: Mat): List = postprocess(forward(preprocess(input))) -} diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt deleted file mode 100644 index 356168d2bc..0000000000 --- 
a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt +++ /dev/null @@ -1,117 +0,0 @@ -package com.swmansion.rnexecutorch.models.ocr - -import com.facebook.react.bridge.Arguments -import com.facebook.react.bridge.ReactApplicationContext -import com.facebook.react.bridge.WritableArray -import com.swmansion.rnexecutorch.models.ocr.utils.CTCLabelConverter -import com.swmansion.rnexecutorch.models.ocr.utils.Constants -import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox -import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils -import com.swmansion.rnexecutorch.utils.ImageProcessor -import org.opencv.core.Core -import org.opencv.core.Mat - -class RecognitionHandler( - symbols: String, - reactApplicationContext: ReactApplicationContext, -) { - private val recognizerLarge = Recognizer(reactApplicationContext) - private val recognizerMedium = Recognizer(reactApplicationContext) - private val recognizerSmall = Recognizer(reactApplicationContext) - private val converter = CTCLabelConverter(symbols) - - private fun runModel(croppedImage: Mat): Pair, Double> { - val result: Pair, Double> = - if (croppedImage.cols() >= Constants.LARGE_MODEL_WIDTH) { - recognizerLarge.runModel(croppedImage) - } else if (croppedImage.cols() >= Constants.MEDIUM_MODEL_WIDTH) { - recognizerMedium.runModel(croppedImage) - } else { - recognizerSmall.runModel(croppedImage) - } - - return result - } - - fun loadRecognizers( - largeRecognizerPath: String, - mediumRecognizerPath: String, - smallRecognizerPath: String, - onComplete: (Int, Exception?) 
-> Unit, - ) { - try { - recognizerLarge.loadModel(largeRecognizerPath) - recognizerMedium.loadModel(mediumRecognizerPath) - recognizerSmall.loadModel(smallRecognizerPath) - onComplete(0, null) - } catch (e: Exception) { - onComplete(1, e) - } - } - - fun recognize( - bBoxesList: List, - imgGray: Mat, - desiredWidth: Int, - desiredHeight: Int, - ): WritableArray { - val res: WritableArray = Arguments.createArray() - val ratioAndPadding = - RecognizerUtils.calculateResizeRatioAndPaddings( - imgGray.width(), - imgGray.height(), - desiredWidth, - desiredHeight, - ) - - val left = ratioAndPadding["left"] as Int - val top = ratioAndPadding["top"] as Int - val resizeRatio = ratioAndPadding["resizeRatio"] as Float - val resizedImg = - ImageProcessor.resizeWithPadding( - imgGray, - desiredWidth, - desiredHeight, - ) - - for (box in bBoxesList) { - var croppedImage = RecognizerUtils.getCroppedImage(box, resizedImg, Constants.MODEL_HEIGHT) - if (croppedImage.empty()) { - continue - } - - croppedImage = RecognizerUtils.normalizeForRecognizer(croppedImage, Constants.ADJUST_CONTRAST) - - var result = runModel(croppedImage) - var confidenceScore = result.second - - if (confidenceScore < Constants.LOW_CONFIDENCE_THRESHOLD) { - Core.rotate(croppedImage, croppedImage, Core.ROTATE_180) - val rotatedResult = runModel(croppedImage) - val rotatedConfidenceScore = rotatedResult.second - if (rotatedConfidenceScore > confidenceScore) { - result = rotatedResult - confidenceScore = rotatedConfidenceScore - } - } - - val predIndex = result.first - val decodedTexts = converter.decodeGreedy(predIndex, predIndex.size) - - for (bBox in box.bBox) { - bBox.x = (bBox.x - left) * resizeRatio - bBox.y = (bBox.y - top) * resizeRatio - } - - val resMap = Arguments.createMap() - - resMap.putString("text", decodedTexts[0]) - resMap.putArray("bbox", box.toWritableArray()) - resMap.putDouble("score", confidenceScore) - - res.pushMap(resMap) - } - - return res - } -} diff --git 
a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt deleted file mode 100644 index 1f6ea14fed..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt +++ /dev/null @@ -1,51 +0,0 @@ -package com.swmansion.rnexecutorch.models.ocr - -import com.facebook.react.bridge.ReactApplicationContext -import com.swmansion.rnexecutorch.models.BaseModel -import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils -import com.swmansion.rnexecutorch.utils.ImageProcessor -import org.opencv.core.Mat -import org.opencv.core.Size -import org.pytorch.executorch.EValue - -class Recognizer( - reactApplicationContext: ReactApplicationContext, -) : BaseModel, Double>>(reactApplicationContext) { - private fun getModelOutputSize(): Size { - val outputShape = module.getOutputShape(0) - val width = outputShape[outputShape.lastIndex] - val height = outputShape[outputShape.lastIndex - 1] - - return Size(height.toDouble(), width.toDouble()) - } - - fun preprocess(input: Mat): EValue = ImageProcessor.matToEValueGray(input) - - fun postprocess(output: Array): Pair, Double> { - val modelOutputHeight = getModelOutputSize().height.toInt() - val tensor = output[0].toTensor().dataAsFloatArray - val numElements = tensor.size - val numRows = (numElements + modelOutputHeight - 1) / modelOutputHeight - val resultMat = Mat(numRows, modelOutputHeight, org.opencv.core.CvType.CV_32F) - var counter = 0 - var currentRow = 0 - for (num in tensor) { - resultMat.put(currentRow, counter, floatArrayOf(num)) - counter++ - if (counter >= modelOutputHeight) { - counter = 0 - currentRow++ - } - } - - var probabilities = RecognizerUtils.softmax(resultMat) - val predsNorm = RecognizerUtils.sumProbabilityRows(probabilities, modelOutputHeight) - probabilities = 
RecognizerUtils.divideMatrixByVector(probabilities, predsNorm) - val (values, indices) = RecognizerUtils.findMaxValuesAndIndices(probabilities) - - val confidenceScore = RecognizerUtils.computeConfidenceScore(values, indices) - return Pair(indices, confidenceScore) - } - - override fun runModel(input: Mat): Pair, Double> = postprocess(module.forward(preprocess(input))) -} diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/VerticalDetector.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/VerticalDetector.kt deleted file mode 100644 index 3d5d7aea17..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/VerticalDetector.kt +++ /dev/null @@ -1,89 +0,0 @@ -package com.swmansion.rnexecutorch.models.ocr - -import com.facebook.react.bridge.ReactApplicationContext -import com.swmansion.rnexecutorch.models.BaseModel -import com.swmansion.rnexecutorch.models.ocr.utils.Constants -import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils -import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox -import com.swmansion.rnexecutorch.utils.ImageProcessor -import org.opencv.core.Mat -import org.opencv.core.Size -import org.pytorch.executorch.EValue - -class VerticalDetector( - private val detectSingleCharacter: Boolean, - reactApplicationContext: ReactApplicationContext, -) : BaseModel>(reactApplicationContext) { - private lateinit var originalSize: Size - - fun getModelImageSize(): Size { - val inputShape = module.getInputShape(0) - val width = inputShape[inputShape.lastIndex - 1] - val height = inputShape[inputShape.lastIndex] - - val modelImageSize = Size(height.toDouble(), width.toDouble()) - - return modelImageSize - } - - fun preprocess(input: Mat): EValue { - originalSize = Size(input.cols().toDouble(), input.rows().toDouble()) - val resizedImage = - ImageProcessor.resizeWithPadding( - input, - 
getModelImageSize().width.toInt(), - getModelImageSize().height.toInt(), - ) - - return ImageProcessor.matToEValue( - resizedImage, - module.getInputShape(0), - Constants.MEAN, - Constants.VARIANCE, - ) - } - - fun postprocess(output: Array): List { - val outputTensor = output[0].toTensor() - val outputArray = outputTensor.dataAsFloatArray - val modelImageSize = getModelImageSize() - - val (scoreText, scoreLink) = - DetectorUtils.interleavedArrayToMats( - outputArray, - Size(modelImageSize.width / 2, modelImageSize.height / 2), - ) - - val txtThreshold = if (detectSingleCharacter) Constants.TEXT_THRESHOLD else Constants.TEXT_THRESHOLD_VERTICAL - var bBoxesList = - DetectorUtils.getDetBoxesFromTextMapVertical( - scoreText, - scoreLink, - txtThreshold, - Constants.LINK_THRESHOLD, - detectSingleCharacter, - ) - - bBoxesList = - DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RESTORE_RATIO_VERTICAL).toFloat()) - - if (detectSingleCharacter) { - return bBoxesList - } - - bBoxesList = - DetectorUtils.groupTextBoxes( - bBoxesList, - Constants.CENTER_THRESHOLD, - Constants.DISTANCE_THRESHOLD, - Constants.HEIGHT_THRESHOLD, - Constants.MIN_SIDE_THRESHOLD, - Constants.MAX_SIDE_THRESHOLD, - Constants.MAX_WIDTH, - ) - - return bBoxesList.toList() - } - - override fun runModel(input: Mat): List = postprocess(forward(preprocess(input))) -} diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt deleted file mode 100644 index b12538c231..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt +++ /dev/null @@ -1,58 +0,0 @@ -package com.swmansion.rnexecutorch.models.ocr.utils - -class CTCLabelConverter( - characters: String, -) { - private val dict = mutableMapOf() - private val character: List - 
private val ignoreIdx: List - - init { - val mutableCharacters = mutableListOf("[blank]") - characters.forEachIndexed { index, char -> - mutableCharacters.add(char.toString()) - dict[char.toString()] = index + 1 - } - character = mutableCharacters.toList() - - val ignoreIndexes = mutableListOf(0) - - ignoreIdx = ignoreIndexes.toList() - } - - fun decodeGreedy( - textIndex: List, - length: Int, - ): List { - val texts = mutableListOf() - var index = 0 - while (index < textIndex.size) { - val segmentLength = minOf(length, textIndex.size - index) - val subArray = textIndex.subList(index, index + segmentLength) - - val text = StringBuilder() - var lastChar: Int? = null - val isNotRepeated = mutableListOf(true) - val isNotIgnored = mutableListOf() - - subArray.forEachIndexed { i, currentChar -> - if (i > 0) { - isNotRepeated.add(lastChar != currentChar) - } - isNotIgnored.add(!ignoreIdx.contains(currentChar)) - lastChar = currentChar - } - - subArray.forEachIndexed { j, charIndex -> - if (isNotRepeated[j] && isNotIgnored[j]) { - text.append(character[charIndex]) - } - } - - texts.add(text.toString()) - index += segmentLength - if (segmentLength < length) break - } - return texts.toList() - } -} diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt deleted file mode 100644 index 5dc25cd796..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt +++ /dev/null @@ -1,31 +0,0 @@ -package com.swmansion.rnexecutorch.models.ocr.utils - -import org.opencv.core.Scalar - -class Constants { - companion object { - const val RECOGNIZER_RATIO = 1.6 - const val RESTORE_RATIO_VERTICAL = 2.0 - const val MODEL_HEIGHT = 64 - const val LARGE_MODEL_WIDTH = 512 - const val MEDIUM_MODEL_WIDTH = 256 - const val SMALL_MODEL_WIDTH = 128 - 
const val VERTICAL_SMALL_MODEL_WIDTH = 64 - const val LOW_CONFIDENCE_THRESHOLD = 0.3 - const val ADJUST_CONTRAST = 0.2 - const val TEXT_THRESHOLD = 0.4 - const val TEXT_THRESHOLD_VERTICAL = 0.3 - const val LINK_THRESHOLD = 0.4 - const val LOW_TEXT_THRESHOLD = 0.7 - const val CENTER_THRESHOLD = 0.5 - const val DISTANCE_THRESHOLD = 2.0 - const val HEIGHT_THRESHOLD = 2.0 - const val MIN_SIDE_THRESHOLD = 15 - const val MAX_SIDE_THRESHOLD = 30 - const val MAX_WIDTH = (LARGE_MODEL_WIDTH + (LARGE_MODEL_WIDTH * 0.15)).toInt() - const val MIN_SIZE = 20 - const val SINGLE_CHARACTER_MIN_SIZE = 70 - val MEAN = Scalar(0.485, 0.456, 0.406) - val VARIANCE = Scalar(0.229, 0.224, 0.225) - } -} diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt deleted file mode 100644 index 7a0d5e9e0d..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt +++ /dev/null @@ -1,608 +0,0 @@ -package com.swmansion.rnexecutorch.models.ocr.utils - -import com.facebook.react.bridge.Arguments -import com.facebook.react.bridge.WritableArray -import org.opencv.core.Core -import org.opencv.core.CvType -import org.opencv.core.Mat -import org.opencv.core.MatOfFloat4 -import org.opencv.core.MatOfInt -import org.opencv.core.MatOfPoint -import org.opencv.core.MatOfPoint2f -import org.opencv.core.Point -import org.opencv.core.Rect -import org.opencv.core.Scalar -import org.opencv.core.Size -import org.opencv.imgproc.Imgproc -import kotlin.math.abs -import kotlin.math.atan -import kotlin.math.cos -import kotlin.math.max -import kotlin.math.min -import kotlin.math.pow -import kotlin.math.sin -import kotlin.math.sqrt - -class DetectorUtils { - companion object { - private fun normalizeAngle(angle: Double): Double { - if (angle > 45.0) { - return 
angle - 90.0 - } - - return angle - } - - private fun midpoint( - p1: BBoxPoint, - p2: BBoxPoint, - ): BBoxPoint { - val midpoint = BBoxPoint((p1.x + p2.x) / 2, (p1.y + p2.y) / 2) - return midpoint - } - - private fun distanceBetweenPoints( - p1: BBoxPoint, - p2: BBoxPoint, - ): Double = sqrt((p1.x - p2.x).pow(2.0) + (p1.y - p2.y).pow(2.0)) - - private fun centerOfBox(box: OCRbBox): BBoxPoint { - val p1 = box.bBox[0] - val p2 = box.bBox[2] - return midpoint(p1, p2) - } - - private fun maxSideLength(box: OCRbBox): Double { - var maxSideLength = 0.0 - val numOfPoints = box.bBox.size - for (i in 0 until numOfPoints) { - val currentPoint = box.bBox[i] - val nextPoint = box.bBox[(i + 1) % numOfPoints] - val sideLength = distanceBetweenPoints(currentPoint, nextPoint) - if (sideLength > maxSideLength) { - maxSideLength = sideLength - } - } - return maxSideLength - } - - private fun minSideLength(box: OCRbBox): Double { - var minSideLength = Double.MAX_VALUE - val numOfPoints = box.bBox.size - for (i in 0 until numOfPoints) { - val currentPoint = box.bBox[i] - val nextPoint = box.bBox[(i + 1) % numOfPoints] - val sideLength = distanceBetweenPoints(currentPoint, nextPoint) - if (sideLength < minSideLength) { - minSideLength = sideLength - } - } - return minSideLength - } - - private fun calculateMinimalDistanceBetweenBoxes( - box1: OCRbBox, - box2: OCRbBox, - ): Double { - var minDistance = Double.MAX_VALUE - for (i in 0 until 4) { - for (j in 0 until 4) { - val distance = distanceBetweenPoints(box1.bBox[i], box2.bBox[j]) - if (distance < minDistance) { - minDistance = distance - } - } - } - - return minDistance - } - - private fun rotateBox( - box: OCRbBox, - angle: Double, - ): OCRbBox { - val center = centerOfBox(box) - val radians = angle * Math.PI / 180 - val newBBox = - box.bBox.map { point -> - val translatedX = point.x - center.x - val translatedY = point.y - center.y - val rotatedX = translatedX * cos(radians) - translatedY * sin(radians) - val rotatedY = 
translatedX * sin(radians) + translatedY * cos(radians) - BBoxPoint(rotatedX + center.x, rotatedY + center.y) - } - - return OCRbBox(newBBox, box.angle) - } - - private fun orderPointsClockwise(box: OCRbBox): OCRbBox { - var topLeft = box.bBox[0] - var topRight = box.bBox[1] - var bottomRight = box.bBox[2] - var bottomLeft = box.bBox[3] - var minSum = Double.MAX_VALUE - var maxSum = -Double.MAX_VALUE - var minDiff = Double.MAX_VALUE - var maxDiff = -Double.MAX_VALUE - - for (point in box.bBox) { - val sum = point.x + point.y - val diff = point.x - point.y - if (sum < minSum) { - minSum = sum - topLeft = point - } - if (sum > maxSum) { - maxSum = sum - bottomRight = point - } - if (diff < minDiff) { - minDiff = diff - bottomLeft = point - } - if (diff > maxDiff) { - maxDiff = diff - topRight = point - } - } - - return OCRbBox(listOf(topLeft, topRight, bottomRight, bottomLeft), box.angle) - } - - private fun mergeRotatedBoxes( - box1: OCRbBox, - box2: OCRbBox, - ): OCRbBox { - val orderedBox1 = orderPointsClockwise(box1) - val orderedBox2 = orderPointsClockwise(box2) - - val allPoints = arrayListOf() - allPoints.addAll(orderedBox1.bBox.map { Point(it.x, it.y) }) - allPoints.addAll(orderedBox2.bBox.map { Point(it.x, it.y) }) - - val matOfAllPoints = MatOfPoint() - matOfAllPoints.fromList(allPoints) - - val hullIndices = MatOfInt() - Imgproc.convexHull(matOfAllPoints, hullIndices, false) - - val hullPoints = hullIndices.toArray().map { allPoints[it] } - - val matOfHullPoints = MatOfPoint2f() - matOfHullPoints.fromList(hullPoints) - - val minAreaRect = Imgproc.minAreaRect(matOfHullPoints) - val rectPoints = arrayOfNulls(4) - minAreaRect.points(rectPoints) - - val bBoxPoints = rectPoints.filterNotNull().map { BBoxPoint(it.x, it.y) } - - return OCRbBox(bBoxPoints, minAreaRect.angle) - } - - private fun removeSmallBoxes( - boxes: MutableList, - minSideThreshold: Int, - maxSideThreshold: Int, - ): MutableList = - boxes - .filter { minSideLength(it) > minSideThreshold && 
maxSideLength(it) > maxSideThreshold } - .toMutableList() - - private fun minimumYFromBox(box: List): Double = box.minOf { it.y } - - private fun fitLineToShortestSides(box: OCRbBox): LineInfo { - val sides = mutableListOf>() - val midpoints = mutableListOf() - - for (i in box.bBox.indices) { - val p1 = box.bBox[i] - val p2 = box.bBox[(i + 1) % 4] - val sideLength = distanceBetweenPoints(p1, p2) - sides.add(sideLength to i) - midpoints.add(midpoint(p1, p2)) - } - - sides.sortBy { it.first } - - val midpoint1 = midpoints[sides[0].second] - val midpoint2 = midpoints[sides[1].second] - - val dx = abs(midpoint2.x - midpoint1.x) - val line = MatOfFloat4() - - val isVertical = - if (dx < 20) { - for (point in arrayOf(midpoint1, midpoint2)) { - val temp = point.x - point.x = point.y - point.y = temp - } - Imgproc.fitLine( - MatOfPoint2f( - Point(midpoint1.x, midpoint1.y), - Point(midpoint2.x, midpoint2.y), - ), - line, - Imgproc.DIST_L2, - 0.0, - 0.01, - 0.01, - ) - true - } else { - Imgproc.fitLine( - MatOfPoint2f( - Point(midpoint1.x, midpoint1.y), - Point(midpoint2.x, midpoint2.y), - ), - line, - Imgproc.DIST_L2, - 0.0, - 0.01, - 0.01, - ) - false - } - - val m = line.get(1, 0)[0] / line.get(0, 0)[0] // slope - val c = line.get(3, 0)[0] - m * line.get(2, 0)[0] // intercept - return LineInfo(m, c, isVertical) - } - - private fun findClosestBox( - boxes: MutableList, - ignoredIds: Set, - currentBox: OCRbBox, - isVertical: Boolean, - m: Double, - c: Double, - centerThreshold: Double, - ): Pair? 
{ - var smallestDistance = Double.MAX_VALUE - var idx = -1 - var boxHeight = 0.0 - val centerOfCurrentBox = centerOfBox(currentBox) - boxes.forEachIndexed { i, box -> - if (ignoredIds.contains(i)) { - return@forEachIndexed - } - val centerOfProcessedBox = centerOfBox(box) - val distanceBetweenCenters = distanceBetweenPoints(centerOfCurrentBox, centerOfProcessedBox) - if (distanceBetweenCenters >= smallestDistance) { - return@forEachIndexed - } - boxHeight = minSideLength(box) - val lineDistance = - if (isVertical) { - abs(centerOfProcessedBox.x - (m * centerOfProcessedBox.y + c)) - } else { - abs(centerOfProcessedBox.y - (m * centerOfProcessedBox.x + c)) - } - - if (lineDistance < boxHeight * centerThreshold) { - idx = i - smallestDistance = distanceBetweenCenters - } - } - - return if (idx == -1) null else Pair(idx, boxHeight) - } - - private fun createMaskFromLabels( - labels: Mat, - labelValue: Int, - ): Mat { - val mask = Mat.zeros(labels.size(), CvType.CV_8U) - - Core.compare(labels, Scalar(labelValue.toDouble()), mask, Core.CMP_EQ) - - return mask - } - - fun interleavedArrayToMats( - array: FloatArray, - size: Size, - ): Pair { - val mat1 = Mat(size.height.toInt(), size.width.toInt(), CvType.CV_32F) - val mat2 = Mat(size.height.toInt(), size.width.toInt(), CvType.CV_32F) - - array.forEachIndexed { index, value -> - val x = (index / 2) % (size.width.toInt()) - val y = (index / 2) / size.width.toInt() - if (index % 2 == 0) { - mat1.put(y, x, value.toDouble()) - } else { - mat2.put(y, x, value.toDouble()) - } - } - - return Pair(mat1, mat2) - } - - fun getDetBoxesFromTextMapVertical( - textMap: Mat, - affinityMap: Mat, - textThreshold: Double, - linkThreshold: Double, - independentCharacters: Boolean, - ): MutableList { - val imgH = textMap.rows() - val imgW = textMap.cols() - - val textScore = Mat() - val affinityScore = Mat() - Imgproc.threshold(textMap, textScore, textThreshold, 1.0, Imgproc.THRESH_BINARY) - Imgproc.threshold(affinityMap, affinityScore, 
linkThreshold, 1.0, Imgproc.THRESH_BINARY) - val textScoreComb = Mat() - var kernel = - Imgproc.getStructuringElement( - Imgproc.MORPH_RECT, - Size(3.0, 3.0), - ) - if (independentCharacters) { - Core.subtract(textScore, affinityScore, textScoreComb) - Imgproc.threshold(textScoreComb, textScoreComb, 0.0, 0.0, Imgproc.THRESH_TOZERO) - Imgproc.threshold(textScoreComb, textScoreComb, 1.0, 1.0, Imgproc.THRESH_TRUNC) - Imgproc.erode(textScoreComb, textScoreComb, kernel, Point(-1.0, -1.0), 1) - Imgproc.dilate(textScoreComb, textScoreComb, kernel, Point(-1.0, -1.0), 4) - } else { - Core.add(textScore, affinityScore, textScoreComb) - Imgproc.threshold(textScoreComb, textScoreComb, 0.0, 0.0, Imgproc.THRESH_TOZERO) - Imgproc.threshold(textScoreComb, textScoreComb, 1.0, 1.0, Imgproc.THRESH_TRUNC) - Imgproc.dilate(textScoreComb, textScoreComb, kernel, Point(-1.0, -1.0), 2) - } - - val binaryMat = Mat() - textScoreComb.convertTo(binaryMat, CvType.CV_8UC1) - - val labels = Mat() - val stats = Mat() - val centroids = Mat() - val nLabels = Imgproc.connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4) - - val detectedBoxes = mutableListOf() - for (i in 1 until nLabels) { - val area = stats.get(i, Imgproc.CC_STAT_AREA)[0].toInt() - if (area < Constants.MIN_SIZE) continue - - val height = stats.get(i, Imgproc.CC_STAT_HEIGHT)[0].toInt() - val width = stats.get(i, Imgproc.CC_STAT_WIDTH)[0].toInt() - - if (!independentCharacters && height < width) continue - val mask = createMaskFromLabels(labels, i) - - val segMap = Mat.zeros(textMap.size(), CvType.CV_8U) - segMap.setTo(Scalar(255.0), mask) - - val x = stats.get(i, Imgproc.CC_STAT_LEFT)[0].toInt() - val y = stats.get(i, Imgproc.CC_STAT_TOP)[0].toInt() - val w = stats.get(i, Imgproc.CC_STAT_WIDTH)[0].toInt() - val h = stats.get(i, Imgproc.CC_STAT_HEIGHT)[0].toInt() - val dilationRadius = (sqrt(area / max(w, h).toDouble()) * 2.0).toInt() - val sx = max(x - dilationRadius, 0) - val ex = min(x + w + dilationRadius + 1, imgW) 
- val sy = max(y - dilationRadius, 0) - val ey = min(y + h + dilationRadius + 1, imgH) - val roi = Rect(sx, sy, ex - sx, ey - sy) - kernel = - Imgproc.getStructuringElement( - Imgproc.MORPH_RECT, - Size((1 + dilationRadius).toDouble(), (1 + dilationRadius).toDouble()), - ) - val roiSegMap = Mat(segMap, roi) - Imgproc.dilate(roiSegMap, roiSegMap, kernel, Point(-1.0, -1.0), 2) - - val contours: List = ArrayList() - Imgproc.findContours( - segMap, - contours, - Mat(), - Imgproc.RETR_EXTERNAL, - Imgproc.CHAIN_APPROX_SIMPLE, - ) - if (contours.isNotEmpty()) { - val minRect = Imgproc.minAreaRect(MatOfPoint2f(*contours[0].toArray())) - val points = Array(4) { Point() } - minRect.points(points) - val pointsList = points.map { point -> BBoxPoint(point.x, point.y) } - val boxInfo = OCRbBox(pointsList, minRect.angle) - detectedBoxes.add(boxInfo) - } - } - - return detectedBoxes - } - - fun getDetBoxesFromTextMap( - textMap: Mat, - affinityMap: Mat, - textThreshold: Double, - linkThreshold: Double, - lowTextThreshold: Double, - ): MutableList { - val imgH = textMap.rows() - val imgW = textMap.cols() - - val textScore = Mat() - val affinityScore = Mat() - Imgproc.threshold(textMap, textScore, textThreshold, 1.0, Imgproc.THRESH_BINARY) - Imgproc.threshold(affinityMap, affinityScore, linkThreshold, 1.0, Imgproc.THRESH_BINARY) - val textScoreComb = Mat() - Core.add(textScore, affinityScore, textScoreComb) - Imgproc.threshold(textScoreComb, textScoreComb, 0.0, 1.0, Imgproc.THRESH_BINARY) - - val binaryMat = Mat() - textScoreComb.convertTo(binaryMat, CvType.CV_8UC1) - - val labels = Mat() - val stats = Mat() - val centroids = Mat() - val nLabels = Imgproc.connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4) - - val detectedBoxes = mutableListOf() - for (i in 1 until nLabels) { - val area = stats.get(i, Imgproc.CC_STAT_AREA)[0].toInt() - if (area < 10) continue - val mask = createMaskFromLabels(labels, i) - val maxValResult = Core.minMaxLoc(textMap, mask) - val maxVal 
= maxValResult.maxVal - if (maxVal < lowTextThreshold) continue - val segMap = Mat.zeros(textMap.size(), CvType.CV_8U) - segMap.setTo(Scalar(255.0), mask) - - val x = stats.get(i, Imgproc.CC_STAT_LEFT)[0].toInt() - val y = stats.get(i, Imgproc.CC_STAT_TOP)[0].toInt() - val w = stats.get(i, Imgproc.CC_STAT_WIDTH)[0].toInt() - val h = stats.get(i, Imgproc.CC_STAT_HEIGHT)[0].toInt() - val dilationRadius = (sqrt(area / max(w, h).toDouble()) * 2.0).toInt() - val sx = max(x - dilationRadius, 0) - val ex = min(x + w + dilationRadius + 1, imgW) - val sy = max(y - dilationRadius, 0) - val ey = min(y + h + dilationRadius + 1, imgH) - val roi = Rect(sx, sy, ex - sx, ey - sy) - val kernel = - Imgproc.getStructuringElement( - Imgproc.MORPH_RECT, - Size((1 + dilationRadius).toDouble(), (1 + dilationRadius).toDouble()), - ) - val roiSegMap = Mat(segMap, roi) - Imgproc.dilate(roiSegMap, roiSegMap, kernel) - - val contours: List = ArrayList() - Imgproc.findContours( - segMap, - contours, - Mat(), - Imgproc.RETR_EXTERNAL, - Imgproc.CHAIN_APPROX_SIMPLE, - ) - if (contours.isNotEmpty()) { - val minRect = Imgproc.minAreaRect(MatOfPoint2f(*contours[0].toArray())) - val points = Array(4) { Point() } - minRect.points(points) - val pointsList = points.map { point -> BBoxPoint(point.x, point.y) } - val boxInfo = OCRbBox(pointsList, minRect.angle) - detectedBoxes.add(boxInfo) - } - } - - return detectedBoxes - } - - fun restoreBoxRatio( - boxes: MutableList, - restoreRatio: Float, - ): MutableList { - for (box in boxes) { - for (b in box.bBox) { - b.x *= restoreRatio - b.y *= restoreRatio - } - } - - return boxes - } - - fun groupTextBoxes( - boxes: MutableList, - centerThreshold: Double, - distanceThreshold: Double, - heightThreshold: Double, - minSideThreshold: Int, - maxSideThreshold: Int, - maxWidth: Int, - ): MutableList { - boxes.sortByDescending { maxSideLength(it) } - var mergedArray = mutableListOf() - - while (boxes.isNotEmpty()) { - var currentBox = boxes.removeAt(0) - val 
normalizedAngle = normalizeAngle(currentBox.angle) - val ignoredIds = mutableSetOf() - var lineAngle: Double - while (true) { - val fittedLine = - fitLineToShortestSides(currentBox) - val slope = fittedLine.slope - val intercept = fittedLine.intercept - val isVertical = fittedLine.isVertical - - lineAngle = atan(slope) * 180 / Math.PI - if (isVertical) { - lineAngle = -90.0 - } - - val closestBoxInfo = - findClosestBox( - boxes, - ignoredIds, - currentBox, - isVertical, - slope, - intercept, - centerThreshold, - ) ?: break - - val candidateIdx = closestBoxInfo.first - var candidateBox = boxes[candidateIdx] - val candidateHeight = closestBoxInfo.second - if ((candidateBox.angle == 90.0 && !isVertical) || (candidateBox.angle == 0.0 && isVertical)) { - candidateBox = - rotateBox(candidateBox, normalizedAngle) - } - val minDistance = - calculateMinimalDistanceBetweenBoxes(candidateBox, currentBox) - val mergedHeight = minSideLength(currentBox) - if (minDistance < distanceThreshold * candidateHeight && - abs(mergedHeight - candidateHeight) < candidateHeight * heightThreshold - ) { - currentBox = mergeRotatedBoxes(currentBox, candidateBox) - boxes.removeAt(candidateIdx) - ignoredIds.clear() - if (maxSideLength(currentBox) > maxWidth) { - break - } - } else { - ignoredIds.add(candidateIdx) - } - } - mergedArray.add(currentBox.copy(angle = lineAngle)) - } - - mergedArray = removeSmallBoxes(mergedArray, minSideThreshold, maxSideThreshold) - mergedArray = mergedArray.sortedWith(compareBy { minimumYFromBox(it.bBox) }).toMutableList() - - mergedArray = mergedArray.map { box -> orderPointsClockwise(box) }.toMutableList() - - return mergedArray - } - } -} - -data class BBoxPoint( - var x: Double, - var y: Double, -) - -data class OCRbBox( - val bBox: List, - val angle: Double, -) { - fun toWritableArray(): WritableArray { - val array = Arguments.createArray() - bBox.forEach { point -> - val pointMap = Arguments.createMap() - pointMap.putDouble("x", point.x) - 
pointMap.putDouble("y", point.y) - array.pushMap(pointMap) - } - return array - } -} - -data class LineInfo( - val slope: Double, - val intercept: Double, - val isVertical: Boolean, -) diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt deleted file mode 100644 index 79245874fb..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt +++ /dev/null @@ -1,430 +0,0 @@ -package com.swmansion.rnexecutorch.models.ocr.utils - -import com.swmansion.rnexecutorch.utils.ImageProcessor -import org.opencv.core.Core -import org.opencv.core.CvType -import org.opencv.core.Mat -import org.opencv.core.MatOfFloat -import org.opencv.core.MatOfInt -import org.opencv.core.MatOfPoint2f -import org.opencv.core.Point -import org.opencv.core.Rect -import org.opencv.core.Scalar -import org.opencv.core.Size -import org.opencv.imgproc.Imgproc -import kotlin.math.max -import kotlin.math.min -import kotlin.math.pow -import kotlin.math.sqrt - -class RecognizerUtils { - companion object { - private fun calculateRatio( - width: Int, - height: Int, - ): Double { - var ratio = width.toDouble() / height.toDouble() - if (ratio < 1.0) { - ratio = 1.0 / ratio - } - - return ratio - } - - private fun findIntersection( - r1: Rect, - r2: Rect, - ): Rect { - val aLeft = r1.x - val aTop = r1.y - val aRight = r1.x + r1.width - val aBottom = r1.y + r1.height - - val bLeft = r2.x - val bTop = r2.y - val bRight = r2.x + r2.width - val bBottom = r2.y + r2.height - - val iLeft = max(aLeft, bLeft) - val iTop = max(aTop, bTop) - val iRight = min(aRight, bRight) - val iBottom = min(aBottom, bBottom) - - return if (iRight > iLeft && iBottom > iTop) { - Rect(iLeft, iTop, iRight - iLeft, iBottom - iTop) - } else { - Rect() - } - } - - private fun 
adjustContrastGrey( - img: Mat, - target: Double, - ): Mat { - var high = 0 - var low = 255 - - for (i in 0 until img.rows()) { - for (j in 0 until img.cols()) { - val pixel = img.get(i, j)[0].toInt() - high = maxOf(high, pixel) - low = minOf(low, pixel) - } - } - - val contrast = (high - low) / 255.0 - - if (contrast < target) { - val ratio = 200.0 / maxOf(10, high - low) - val tempImg = Mat() - img.convertTo(tempImg, CvType.CV_32F) - Core.subtract(tempImg, Scalar(low.toDouble() - 25), tempImg) - Core.multiply(tempImg, Scalar(ratio), tempImg) - Imgproc.threshold(tempImg, tempImg, 255.0, 255.0, Imgproc.THRESH_TRUNC) - Imgproc.threshold(tempImg, tempImg, 0.0, 255.0, Imgproc.THRESH_TOZERO) - tempImg.convertTo(tempImg, CvType.CV_8U) - - return tempImg - } - - return img - } - - private fun computeRatioAndResize( - img: Mat, - width: Int, - height: Int, - modelHeight: Int, - ): Mat { - var ratio = width.toDouble() / height.toDouble() - - if (ratio < 1.0) { - ratio = - calculateRatio(width, height) - Imgproc.resize( - img, - img, - Size(modelHeight.toDouble(), (modelHeight * ratio)), - 0.0, - 0.0, - Imgproc.INTER_LANCZOS4, - ) - } else { - Imgproc.resize( - img, - img, - Size((modelHeight * ratio), modelHeight.toDouble()), - 0.0, - 0.0, - Imgproc.INTER_LANCZOS4, - ) - } - - return img - } - - fun softmax(inputs: Mat): Mat { - val maxVal = Mat() - Core.reduce(inputs, maxVal, 1, Core.REDUCE_MAX, CvType.CV_32F) - - val tiledMaxVal = Mat() - Core.repeat(maxVal, 1, inputs.width(), tiledMaxVal) - val expInputs = Mat() - Core.subtract(inputs, tiledMaxVal, expInputs) - Core.exp(expInputs, expInputs) - - val sumExp = Mat() - Core.reduce(expInputs, sumExp, 1, Core.REDUCE_SUM, CvType.CV_32F) - - val tiledSumExp = Mat() - Core.repeat(sumExp, 1, inputs.width(), tiledSumExp) - val softmaxOutput = Mat() - Core.divide(expInputs, tiledSumExp, softmaxOutput) - - return softmaxOutput - } - - fun sumProbabilityRows( - probabilities: Mat, - modelOutputHeight: Int, - ): FloatArray { - val 
predsNorm = FloatArray(probabilities.rows()) - - for (i in 0 until probabilities.rows()) { - var sum = 0.0 - for (j in 0 until modelOutputHeight) { - sum += probabilities.get(i, j)[0] - } - predsNorm[i] = sum.toFloat() - } - - return predsNorm - } - - fun divideMatrixByVector( - matrix: Mat, - vector: FloatArray, - ): Mat { - for (i in 0 until matrix.rows()) { - for (j in 0 until matrix.cols()) { - val value = matrix.get(i, j)[0] / vector[i] - matrix.put(i, j, value) - } - } - - return matrix - } - - fun findMaxValuesAndIndices(probabilities: Mat): Pair> { - val values = DoubleArray(probabilities.rows()) - val indices = mutableListOf() - - for (i in 0 until probabilities.rows()) { - val row = probabilities.row(i) - val minMaxLocResult = Core.minMaxLoc(row) - - values[i] = minMaxLocResult.maxVal - indices.add(minMaxLocResult.maxLoc.x.toInt()) - } - - return Pair(values, indices) - } - - fun computeConfidenceScore( - valuesArray: DoubleArray, - indicesArray: List, - ): Double { - val predsMaxProb = mutableListOf() - for ((index, value) in indicesArray.withIndex()) { - if (value != 0) predsMaxProb.add(valuesArray[index]) - } - - val nonZeroValues = - if (predsMaxProb.isEmpty()) doubleArrayOf(0.0) else predsMaxProb.toDoubleArray() - val product = nonZeroValues.reduce { acc, d -> acc * d } - val score = product.pow(2.0 / sqrt(nonZeroValues.size.toDouble())) - - return score - } - - fun calculateResizeRatioAndPaddings( - width: Int, - height: Int, - desiredWidth: Int, - desiredHeight: Int, - ): Map { - val newRatioH = desiredHeight.toFloat() / height - val newRatioW = desiredWidth.toFloat() / width - var resizeRatio = minOf(newRatioH, newRatioW) - - val newWidth = (width * resizeRatio).toInt() - val newHeight = (height * resizeRatio).toInt() - - val deltaW = desiredWidth - newWidth - val deltaH = desiredHeight - newHeight - - val top = deltaH / 2 - val left = deltaW / 2 - - val heightRatio = height.toFloat() / desiredHeight - val widthRatio = width.toFloat() / 
desiredWidth - - resizeRatio = maxOf(heightRatio, widthRatio) - - return mapOf( - "resizeRatio" to resizeRatio, - "top" to top, - "left" to left, - ) - } - - fun getCroppedImage( - box: OCRbBox, - image: Mat, - modelHeight: Int, - ): Mat { - val cords = box.bBox - val angle = box.angle - val points = ArrayList() - - cords.forEach { point -> - points.add(Point(point.x, point.y)) - } - - val rotatedRect = Imgproc.minAreaRect(MatOfPoint2f(*points.toTypedArray())) - val imageCenter = Point((image.cols() / 2.0), (image.rows() / 2.0)) - val rotationMatrix = Imgproc.getRotationMatrix2D(imageCenter, angle, 1.0) - val rotatedImage = Mat() - Imgproc.warpAffine(image, rotatedImage, rotationMatrix, image.size(), Imgproc.INTER_LINEAR) - - val rectPoints = Array(4) { Point() } - rotatedRect.points(rectPoints) - val transformedPoints = arrayOfNulls(4) - val rectMat = Mat(4, 2, CvType.CV_32FC2) - for (i in 0 until 4) { - rectMat.put(i, 0, *doubleArrayOf(rectPoints[i].x, rectPoints[i].y)) - } - Core.transform(rectMat, rectMat, rotationMatrix) - - for (i in 0 until 4) { - transformedPoints[i] = Point(rectMat.get(i, 0)[0], rectMat.get(i, 0)[1]) - } - - var boundingBox = - Imgproc.boundingRect(MatOfPoint2f(*transformedPoints.filterNotNull().toTypedArray())) - val validRegion = Rect(0, 0, rotatedImage.cols(), rotatedImage.rows()) - boundingBox = findIntersection(boundingBox, validRegion) - val croppedImage = Mat(rotatedImage, boundingBox) - if (croppedImage.empty()) { - return croppedImage - } - - return computeRatioAndResize(croppedImage, boundingBox.width, boundingBox.height, modelHeight) - } - - fun normalizeForRecognizer( - image: Mat, - adjustContrast: Double, - isVertical: Boolean = false, - ): Mat { - var img = image.clone() - - if (adjustContrast > 0) { - img = adjustContrastGrey(img, adjustContrast) - } - - val desiredWidth = - when { - img.width() >= Constants.LARGE_MODEL_WIDTH -> Constants.LARGE_MODEL_WIDTH - img.width() >= Constants.MEDIUM_MODEL_WIDTH -> 
Constants.MEDIUM_MODEL_WIDTH - else -> if (isVertical) Constants.VERTICAL_SMALL_MODEL_WIDTH else Constants.SMALL_MODEL_WIDTH - } - - img = ImageProcessor.resizeWithPadding(img, desiredWidth, Constants.MODEL_HEIGHT) - img.convertTo(img, CvType.CV_32F, 1.0 / 255.0) - Core.subtract(img, Scalar(0.5), img) - Core.multiply(img, Scalar(2.0), img) - - return img - } - - fun cropImageWithBoundingBox( - image: Mat, - box: List, - originalBox: List, - paddings: Map, - originalPaddings: Map, - ): Mat { - val topLeft = originalBox[0] - val points = arrayOfNulls(4) - - for (i in 0 until 4) { - val cords = box[i] - cords.x -= paddings["left"]!! as Int - cords.y -= paddings["top"]!! as Int - - cords.x *= paddings["resizeRatio"]!! as Float - cords.y *= paddings["resizeRatio"]!! as Float - - cords.x += topLeft.x - cords.y += topLeft.y - - cords.x -= originalPaddings["left"]!! as Int - cords.y -= (originalPaddings["top"]!! as Int) - - cords.x *= originalPaddings["resizeRatio"]!! as Float - cords.y *= originalPaddings["resizeRatio"]!! 
as Float - - cords.x = cords.x.coerceIn(0.0, (image.cols() - 1).toDouble()) - cords.y = cords.y.coerceIn(0.0, (image.rows() - 1).toDouble()) - - points[i] = Point(cords.x, cords.y) - } - - val boundingBox = Imgproc.boundingRect(MatOfPoint2f(*points)) - val croppedImage = Mat(image, boundingBox) - Imgproc.cvtColor(croppedImage, croppedImage, Imgproc.COLOR_BGR2GRAY) - Imgproc.resize(croppedImage, croppedImage, Size(64.0, 64.0), 0.0, 0.0, Imgproc.INTER_LANCZOS4) - Imgproc.medianBlur(croppedImage, croppedImage, 1) - - return croppedImage - } - - fun extractBoundingBox(cords: List): Rect { - val points = arrayOfNulls(4) - - for (i in 0 until 4) { - points[i] = Point(cords[i].x, cords[i].y) - } - - val boundingBox = Imgproc.boundingRect(MatOfPoint2f(*points)) - - return boundingBox - } - - fun cropSingleCharacter(img: Mat): Mat { - val histogram = Mat() - val histSize = MatOfInt(256) - val range = MatOfFloat(0f, 256f) - Imgproc.calcHist( - listOf(img), - MatOfInt(0), - Mat(), - histogram, - histSize, - range, - ) - - val midPoint = 256 / 2 - var sumLeft = 0.0 - var sumRight = 0.0 - for (i in 0 until midPoint) { - sumLeft += histogram.get(i, 0)[0] - } - for (i in midPoint until 256) { - sumRight += histogram.get(i, 0)[0] - } - - val thresholdType = if (sumLeft < sumRight) Imgproc.THRESH_BINARY_INV else Imgproc.THRESH_BINARY - - val thresh = Mat() - Imgproc.threshold(img, thresh, 0.0, 255.0, thresholdType + Imgproc.THRESH_OTSU) - - val labels = Mat() - val stats = Mat() - val centroids = Mat() - val numLabels = Imgproc.connectedComponentsWithStats(thresh, labels, stats, centroids, 8) - - val centralThreshold = 0.3 - val height = thresh.rows() - val width = thresh.cols() - val minX = centralThreshold * width - val maxX = (1 - centralThreshold) * width - val minY = centralThreshold * height - val maxY = (1 - centralThreshold) * height - - var selectedComponent = -1 - for (i in 1 until numLabels) { - val area = stats.get(i, Imgproc.CC_STAT_AREA)[0].toInt() - val cx = 
centroids.get(i, 0)[0] - val cy = centroids.get(i, 1)[0] - if (cx > minX && cx < maxX && cy > minY && cy < maxY && area > Constants.SINGLE_CHARACTER_MIN_SIZE) { - if (selectedComponent == -1 || area > stats.get(selectedComponent, Imgproc.CC_STAT_AREA)[0]) { - selectedComponent = i - } - } - } - - val mask = Mat.zeros(img.size(), CvType.CV_8UC1) - if (selectedComponent != -1) { - Core.compare(labels, Scalar(selectedComponent.toDouble()), mask, Core.CMP_EQ) - } - - val resultImage = Mat.zeros(img.size(), img.type()) - img.copyTo(resultImage, mask) - - Core.bitwise_not(resultImage, resultImage) - return resultImage - } - } -} diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt deleted file mode 100644 index 352a3f0ae0..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt +++ /dev/null @@ -1,87 +0,0 @@ -package com.swmansion.rnexecutorch.utils - -import com.facebook.react.bridge.Arguments -import com.facebook.react.bridge.ReadableArray -import com.facebook.react.bridge.WritableArray -import org.pytorch.executorch.DType -import org.pytorch.executorch.EValue -import org.pytorch.executorch.Tensor - -class ArrayUtils { - companion object { - inline fun createTypedArrayFromReadableArray( - input: ReadableArray, - transform: (ReadableArray, Int) -> T, - ): Array = Array(input.size()) { index -> transform(input, index) } - - fun createByteArray(input: ReadableArray): ByteArray = - createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toByte() }.toByteArray() - - fun createCharArray(input: ReadableArray): CharArray = - createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toChar() }.toCharArray() - - fun createIntArray(input: ReadableArray): IntArray = - createTypedArrayFromReadableArray(input) { 
array, index -> array.getInt(index) }.toIntArray() - - fun createFloatArray(input: ReadableArray): FloatArray = - createTypedArrayFromReadableArray(input) { array, index -> array.getDouble(index).toFloat() }.toFloatArray() - - fun createLongArray(input: ReadableArray): LongArray = - createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toLong() }.toLongArray() - - fun createDoubleArray(input: ReadableArray): DoubleArray = - createTypedArrayFromReadableArray(input) { array, index -> array.getDouble(index) }.toDoubleArray() - - fun createReadableArrayFromTensor(result: Tensor): ReadableArray { - val resultArray = Arguments.createArray() - - when (result.dtype()) { - DType.UINT8 -> { - result.dataAsByteArray.forEach { resultArray.pushInt(it.toInt()) } - } - - DType.INT32 -> { - result.dataAsIntArray.forEach { resultArray.pushInt(it) } - } - - DType.FLOAT -> { - result.dataAsFloatArray.forEach { resultArray.pushDouble(it.toDouble()) } - } - - DType.DOUBLE -> { - result.dataAsDoubleArray.forEach { resultArray.pushDouble(it) } - } - - DType.INT64 -> { - // TODO: Do something to handle or deprecate long dtype - // https://github.com/facebook/react-native/issues/12506 - result.dataAsLongArray.forEach { resultArray.pushInt(it.toInt()) } - } - - else -> { - throw IllegalArgumentException("Invalid dtype: ${result.dtype()}") - } - } - - return resultArray - } - - fun createReadableArrayFromFloatArray(input: FloatArray): ReadableArray { - val resultArray = Arguments.createArray() - input.forEach { resultArray.pushDouble(it.toDouble()) } - return resultArray - } - - fun createReadableArrayFromIntArray(input: IntArray): ReadableArray { - val resultArray = Arguments.createArray() - input.forEach { resultArray.pushInt(it) } - return resultArray - } - - fun writableArrayToEValue(input: WritableArray): EValue { - val size = input.size() - val preprocessorInputShape = longArrayOf(1, size.toLong()) - return EValue.from(Tensor.fromBlob(createFloatArray(input), 
preprocessorInputShape)) - } - } -} diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/ETError.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/ETError.kt deleted file mode 100644 index f4cd62c633..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/ETError.kt +++ /dev/null @@ -1,34 +0,0 @@ -package com.swmansion.rnexecutorch.utils - -enum class ETError( - val code: Int, -) { - UndefinedError(0x65), - ModuleNotLoaded(0x66), - FileWriteFailed(0x67), - InvalidModelSource(0xff), - - // System errors - Ok(0x00), - Internal(0x01), - InvalidState(0x02), - EndOfMethod(0x03), - - // Logical errors - NotSupported(0x10), - NotImplemented(0x11), - InvalidArgument(0x12), - InvalidType(0x13), - OperatorMissing(0x14), - - // Resource errors - NotFound(0x20), - MemoryAllocationFailed(0x21), - AccessFailed(0x22), - InvalidProgram(0x23), - - // Delegate errors - DelegateInvalidCompatibility(0x30), - DelegateMemoryAllocationFailed(0x31), - DelegateInvalidHandle(0x32), -} diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/ImageProcessor.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/ImageProcessor.kt deleted file mode 100644 index b8b262e700..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/ImageProcessor.kt +++ /dev/null @@ -1,237 +0,0 @@ -package com.swmansion.rnexecutorch.utils - -import android.content.Context -import android.net.Uri -import android.util.Base64 -import org.opencv.core.Core -import org.opencv.core.CvType -import org.opencv.core.Mat -import org.opencv.core.Scalar -import org.opencv.core.Size -import org.opencv.imgcodecs.Imgcodecs -import org.opencv.imgproc.Imgproc -import org.pytorch.executorch.EValue -import org.pytorch.executorch.Tensor -import java.io.File -import 
java.io.InputStream -import java.net.URL -import java.util.UUID -import kotlin.math.floor - -class ImageProcessor { - companion object { - fun matToEValue( - mat: Mat, - shape: LongArray, - ): EValue = matToEValue(mat, shape, Scalar(0.0, 0.0, 0.0), Scalar(1.0, 1.0, 1.0)) - - fun matToEValue( - mat: Mat, - shape: LongArray, - mean: Scalar, - variance: Scalar, - ): EValue { - val pixelCount = mat.cols() * mat.rows() - val floatArray = FloatArray(pixelCount * 3) - - for (i in 0 until pixelCount) { - val row = i / mat.cols() - val col = i % mat.cols() - val pixel = mat.get(row, col) - - if (mat.type() == CvType.CV_8UC3 || mat.type() == CvType.CV_8UC4) { - val b = (pixel[0] - mean.`val`[0] * 255.0f) / (variance.`val`[0] * 255.0f) - val g = (pixel[1] - mean.`val`[1] * 255.0f) / (variance.`val`[1] * 255.0f) - val r = (pixel[2] - mean.`val`[2] * 255.0f) / (variance.`val`[2] * 255.0f) - - floatArray[0 * pixelCount + i] = b.toFloat() - floatArray[1 * pixelCount + i] = g.toFloat() - floatArray[2 * pixelCount + i] = r.toFloat() - } - } - - return EValue.from(Tensor.fromBlob(floatArray, shape)) - } - - fun matToEValueGray(mat: Mat): EValue { - val pixelCount = mat.cols() * mat.rows() - val floatArray = FloatArray(pixelCount) - - for (i in 0 until pixelCount) { - val row = i / mat.cols() - val col = i % mat.cols() - val pixel = mat.get(row, col) - floatArray[i] = pixel[0].toFloat() - } - - return EValue.from( - Tensor.fromBlob( - floatArray, - longArrayOf(1, 1, mat.rows().toLong(), mat.cols().toLong()), - ), - ) - } - - fun eValueToMat( - array: FloatArray, - width: Int, - height: Int, - ): Mat { - val mat = Mat(height, width, CvType.CV_8UC3) - - val pixelCount = width * height - for (i in 0 until pixelCount) { - val row = i / width - val col = i % width - - val r = (array[i] * 255).toInt().toByte() - val g = (array[pixelCount + i] * 255).toInt().toByte() - val b = (array[2 * pixelCount + i] * 255).toInt().toByte() - - val color = byteArrayOf(b, g, r) - mat.put(row, col, color) 
- } - return mat - } - - fun saveToTempFile( - context: Context, - mat: Mat, - ): String { - try { - val uniqueID = UUID.randomUUID().toString() - val tempFile = File(context.cacheDir, "rn_executorch_$uniqueID.png") - Imgcodecs.imwrite(tempFile.absolutePath, mat) - - return "file://${tempFile.absolutePath}" - } catch (e: Exception) { - throw Exception(ETError.FileWriteFailed.toString()) - } - } - - fun readImage(source: String): Mat { - val inputImage: Mat - - val uri = Uri.parse(source) - val scheme = uri.scheme ?: "" - - when { - scheme.equals("data", ignoreCase = true) -> { - // base64 - val parts = source.split(",", limit = 2) - if (parts.size < 2) throw IllegalArgumentException(ETError.InvalidArgument.toString()) - - val encodedString = parts[1] - val data = Base64.decode(encodedString, Base64.DEFAULT) - - val encodedData = - Mat(1, data.size, CvType.CV_8UC1).apply { - put(0, 0, data) - } - inputImage = Imgcodecs.imdecode(encodedData, Imgcodecs.IMREAD_COLOR) - } - - scheme.equals("file", ignoreCase = true) -> { - // device storage - val path = uri.path - inputImage = Imgcodecs.imread(path, Imgcodecs.IMREAD_COLOR) - } - - else -> { - // external source - val url = URL(source) - val connection = url.openConnection() - connection.connect() - - val inputStream: InputStream = connection.getInputStream() - val data = inputStream.readBytes() - inputStream.close() - - val encodedData = - Mat(1, data.size, CvType.CV_8UC1).apply { - put(0, 0, data) - } - inputImage = Imgcodecs.imdecode(encodedData, Imgcodecs.IMREAD_COLOR) - } - } - - if (inputImage.empty()) { - throw IllegalArgumentException(ETError.InvalidArgument.toString()) - } - - return inputImage - } - - fun resizeWithPadding( - img: Mat, - desiredWidth: Int, - desiredHeight: Int, - ): Mat { - val height = img.rows() - val width = img.cols() - val heightRatio = desiredHeight.toFloat() / height - val widthRatio = desiredWidth.toFloat() / width - val resizeRatio = minOf(heightRatio, widthRatio) - val newWidth = 
(width * resizeRatio).toInt() - val newHeight = (height * resizeRatio).toInt() - - val resizedImg = Mat() - Imgproc.resize( - img, - resizedImg, - Size(newWidth.toDouble(), newHeight.toDouble()), - 0.0, - 0.0, - Imgproc.INTER_AREA, - ) - - val cornerPatchSize = maxOf(1, minOf(width, height) / 30) - val corners = - listOf( - img.submat(0, cornerPatchSize, 0, cornerPatchSize), - img.submat(0, cornerPatchSize, width - cornerPatchSize, width), - img.submat(height - cornerPatchSize, height, 0, cornerPatchSize), - img.submat(height - cornerPatchSize, height, width - cornerPatchSize, width), - ) - - var backgroundScalar = Core.mean(corners[0]) - for (i in 1 until corners.size) { - val mean = Core.mean(corners[i]) - backgroundScalar = - Scalar( - backgroundScalar.`val`[0] + mean.`val`[0], - backgroundScalar.`val`[1] + mean.`val`[1], - backgroundScalar.`val`[2] + mean.`val`[2], - ) - } - - backgroundScalar = - Scalar( - floor(backgroundScalar.`val`[0] / corners.size), - floor(backgroundScalar.`val`[1] / corners.size), - floor(backgroundScalar.`val`[2] / corners.size), - ) - - val deltaW = desiredWidth - newWidth - val deltaH = desiredHeight - newHeight - val top = deltaH / 2 - val bottom = deltaH - top - val left = deltaW / 2 - val right = deltaW - left - - val centeredImg = Mat() - Core.copyMakeBorder( - resizedImg, - centeredImg, - top, - bottom, - left, - right, - Core.BORDER_CONSTANT, - backgroundScalar, - ) - - return centeredImg - } - } -} diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/Numerical.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/Numerical.kt deleted file mode 100644 index 603699e35f..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/Numerical.kt +++ /dev/null @@ -1,8 +0,0 @@ -package com.swmansion.rnexecutorch.utils - -fun softmax(x: Array): Array { - val max = x.maxOrNull()!! 
- val exps = x.map { kotlin.math.exp(it - max) } - val sum = exps.sum() - return exps.map { it / sum }.toTypedArray() -} diff --git a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/TensorUtils.kt b/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/TensorUtils.kt deleted file mode 100644 index ca8552459c..0000000000 --- a/packages/react-native-executorch/android/src/main/java/com/swmansion/rnexecutorch/utils/TensorUtils.kt +++ /dev/null @@ -1,103 +0,0 @@ -package com.swmansion.rnexecutorch.utils - -import android.graphics.Bitmap -import android.graphics.Color -import com.facebook.react.bridge.ReadableArray -import org.pytorch.executorch.EValue -import org.pytorch.executorch.Tensor -import java.nio.FloatBuffer - -class TensorUtils { - companion object { - fun getExecutorchInput( - input: ReadableArray, - shape: LongArray, - type: Int, - ): EValue { - try { - when (type) { - 1 -> { - val inputTensor = Tensor.fromBlob(ArrayUtils.createByteArray(input), shape) - return EValue.from(inputTensor) - } - 3 -> { - val inputTensor = Tensor.fromBlob(ArrayUtils.createIntArray(input), shape) - return EValue.from(inputTensor) - } - 4 -> { - val inputTensor = Tensor.fromBlob(ArrayUtils.createLongArray(input), shape) - return EValue.from(inputTensor) - } - 6 -> { - val inputTensor = Tensor.fromBlob(ArrayUtils.createFloatArray(input), shape) - return EValue.from(inputTensor) - } - 7 -> { - val inputTensor = Tensor.fromBlob(ArrayUtils.createDoubleArray(input), shape) - return EValue.from(inputTensor) - } - - else -> { - throw IllegalArgumentException("Invalid input type: $type") - } - } - } catch (e: IllegalArgumentException) { - throw e - } - } - - fun float32TensorToBitmap(tensor: Tensor): Bitmap { - val shape = tensor.shape() // Assuming the tensor shape is [1, 3, H, W] - val height = shape[2].toInt() - val width = shape[3].toInt() - - val floatArray = tensor.dataAsFloatArray - - val bitmap = 
Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888) - val pixels = IntArray(width * height) - - val offsetG = height * width - val offsetB = 2 * height * width - - for (y in 0 until height) { - for (x in 0 until width) { - val r = Math.round(floatArray[y * width + x] * 255.0f) - val g = Math.round(floatArray[offsetG + y * width + x] * 255.0f) - val b = Math.round(floatArray[offsetB + y * width + x] * 255.0f) - pixels[y * width + x] = (0xFF shl 24) or (r shl 16) or (g shl 8) or b - } - } - - bitmap.setPixels(pixels, 0, width, 0, 0, width, height) - return bitmap - } - - fun bitmapToFloat32Tensor(bitmap: Bitmap): Tensor { - val height = bitmap.height - val width = bitmap.width - val floatBuffer = Tensor.allocateFloatBuffer(3 * width * height) - bitmapToFloatBuffer(bitmap, floatBuffer) - return Tensor.fromBlob(floatBuffer, longArrayOf(1, 3, height.toLong(), width.toLong())) - } - - private fun bitmapToFloatBuffer( - bitmap: Bitmap, - outBuffer: FloatBuffer, - ) { - val pixelsCount = bitmap.height * bitmap.width - val pixels = IntArray(pixelsCount) - bitmap.getPixels(pixels, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height) - val offsetG = pixelsCount - val offsetB = 2 * pixelsCount - for (i in 0 until pixelsCount) { - val c = pixels[i] - val r = Color.red(c) / 255.0f - val g = Color.green(c) / 255.0f - val b = Color.blue(c) / 255.0f - outBuffer.put(i, r) - outBuffer.put(offsetG + i, g) - outBuffer.put(offsetB + i, b) - } - } - } -} diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp index 9b00fa7f7f..7012597cbd 100644 --- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp @@ -8,8 +8,10 @@ #include #include #include +#include #include #include +#include namespace rnexecutorch { @@ -63,14 +65,20 @@ void 
RnExecutorchInstaller::injectJSIBindings( RnExecutorchInstaller::loadModel( jsiRuntime, jsCallInvoker, "loadTextEmbeddings")); + jsiRuntime->global().setProperty( + *jsiRuntime, "loadSpeechToText", + RnExecutorchInstaller::loadModel(jsiRuntime, jsCallInvoker, + "loadSpeechToText")); jsiRuntime->global().setProperty(*jsiRuntime, "loadLLM", RnExecutorchInstaller::loadModel( jsiRuntime, jsCallInvoker, "loadLLM")); + jsiRuntime->global().setProperty(*jsiRuntime, "loadOCR", + RnExecutorchInstaller::loadModel( + jsiRuntime, jsCallInvoker, "loadOCR")); jsiRuntime->global().setProperty( - *jsiRuntime, "loadSpeechToText", - RnExecutorchInstaller::loadModel(jsiRuntime, jsCallInvoker, - "loadSpeechToText")); + *jsiRuntime, "loadVerticalOCR", + RnExecutorchInstaller::loadModel(jsiRuntime, jsCallInvoker, + "loadVerticalOCR")); } - } // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h index 5c0d6ba629..cd24787e3f 100644 --- a/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h +++ b/packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h @@ -38,6 +38,10 @@ REGISTER_CONSTRUCTOR(LLM, std::string, std::string, std::shared_ptr); REGISTER_CONSTRUCTOR(SpeechToText, std::string, std::string, std::string, std::shared_ptr); +REGISTER_CONSTRUCTOR(OCR, std::string, std::string, std::string, std::string, + std::string, std::shared_ptr); +REGISTER_CONSTRUCTOR(VerticalOCR, std::string, std::string, std::string, + std::string, bool, std::shared_ptr); using namespace facebook; @@ -57,7 +61,6 @@ class RnExecutorchInstaller { loadModel(jsi::Runtime *jsiRuntime, std::shared_ptr jsCallInvoker, const std::string &loadFunctionName) { - return jsi::Function::createFromHostFunction( *jsiRuntime, jsi::PropNameID::forAscii(*jsiRuntime, loadFunctionName), 0, @@ -108,5 +111,4 @@ class RnExecutorchInstaller { }); } }; - } 
// namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp index 4bbfc4e389..9e766fa01c 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.cpp @@ -115,8 +115,35 @@ cv::Mat readImage(const std::string &imageURI) { TensorPtr getTensorFromMatrix(const std::vector &tensorDims, const cv::Mat &matrix) { - std::vector inputVector = colorMatToVector(matrix); - return executorch::extension::make_tensor_ptr(tensorDims, inputVector); + return executorch::extension::make_tensor_ptr(tensorDims, + colorMatToVector(matrix)); +} + +TensorPtr getTensorFromMatrix(const std::vector &tensorDims, + const cv::Mat &matrix, cv::Scalar mean, + cv::Scalar variance) { + return executorch::extension::make_tensor_ptr( + tensorDims, colorMatToVector(matrix, mean, variance)); +} + +TensorPtr getTensorFromMatrixGray(const std::vector &tensorDims, + const cv::Mat &matrix) { + return executorch::extension::make_tensor_ptr(tensorDims, + grayMatToVector(matrix)); +} + +std::vector grayMatToVector(const cv::Mat &mat) { + CV_Assert(mat.type() == CV_32F); + if (mat.isContinuous()) { + return {mat.ptr(), mat.ptr() + mat.total()}; + } + + std::vector v; + v.reserve(mat.total()); + for (int i = 0; i < mat.rows; ++i) { + v.insert(v.end(), mat.ptr(i), mat.ptr(i) + mat.cols); + } + return v; } cv::Mat getMatrixFromTensor(cv::Size size, const Tensor &tensor) { @@ -125,9 +152,67 @@ cv::Mat getMatrixFromTensor(cv::Size size, const Tensor &tensor) { size); } +cv::Mat resizePadded(const cv::Mat inputImage, cv::Size targetSize) { + cv::Size inputSize = inputImage.size(); + const float heightRatio = + static_cast(targetSize.height) / inputSize.height; + const float widthRatio = + static_cast(targetSize.width) / 
inputSize.width; + const float resizeRatio = std::min(heightRatio, widthRatio); + const int newWidth = inputSize.width * resizeRatio; + const int newHeight = inputSize.height * resizeRatio; + cv::Mat resizedImg; + cv::resize(inputImage, resizedImg, cv::Size(newWidth, newHeight), 0, 0, + cv::INTER_AREA); + constexpr int minCornerPatchSize = 1; + constexpr int cornerPatchFractionSize = 30; + int cornerPatchSize = + std::min(inputSize.height, inputSize.width) / cornerPatchFractionSize; + cornerPatchSize = std::max(minCornerPatchSize, cornerPatchSize); + + const std::array corners = { + inputImage(cv::Rect(0, 0, cornerPatchSize, cornerPatchSize)), + inputImage(cv::Rect(inputSize.width - cornerPatchSize, 0, cornerPatchSize, + cornerPatchSize)), + inputImage(cv::Rect(0, inputSize.height - cornerPatchSize, + cornerPatchSize, cornerPatchSize)), + inputImage(cv::Rect(inputSize.width - cornerPatchSize, + inputSize.height - cornerPatchSize, cornerPatchSize, + cornerPatchSize))}; + + // We choose the color of the padding based on a mean of colors in the corners + // of an image. 
+ cv::Scalar backgroundScalar = cv::mean(corners[0]); +#pragma unroll + for (size_t i = 1; i < corners.size(); i++) { + backgroundScalar += cv::mean(corners[i]); + } + backgroundScalar /= static_cast(corners.size()); + + constexpr size_t numChannels = 3; +#pragma unroll + for (size_t i = 0; i < numChannels; ++i) { + backgroundScalar[i] = cvFloor(backgroundScalar[i]); + } + + const int deltaW = targetSize.width - newWidth; + const int deltaH = targetSize.height - newHeight; + const int top = deltaH / 2; + const int bottom = deltaH - top; + const int left = deltaW / 2; + const int right = deltaW - left; + + cv::Mat centeredImg; + cv::copyMakeBorder(resizedImg, centeredImg, top, bottom, left, right, + cv::BORDER_CONSTANT, backgroundScalar); + + return centeredImg; +} + std::pair readImageToTensor(const std::string &path, - const std::vector &tensorDims) { + const std::vector &tensorDims, + bool maintainAspectRatio) { cv::Mat input = imageprocessing::readImage(path); cv::Size imageSize = input.size(); @@ -142,7 +227,11 @@ readImageToTensor(const std::string &path, cv::Size tensorSize = cv::Size(tensorDims[tensorDims.size() - 1], tensorDims[tensorDims.size() - 2]); - cv::resize(input, input, tensorSize); + if (maintainAspectRatio) { + input = resizePadded(input, tensorSize); + } else { + cv::resize(input, input, tensorSize); + } cv::cvtColor(input, input, cv::COLOR_BGR2RGB); diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h index 457c108203..ce3652e4eb 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/ImageProcessing.h @@ -1,15 +1,13 @@ #pragma once +#include +#include +#include #include #include #include #include -#include -#include - -#include - namespace rnexecutorch::imageprocessing { using 
executorch::aten::Tensor; using executorch::extension::TensorPtr; @@ -28,11 +26,30 @@ std::string saveToTempFile(const cv::Mat &image); cv::Mat readImage(const std::string &imageURI); TensorPtr getTensorFromMatrix(const std::vector &tensorDims, const cv::Mat &mat); +TensorPtr getTensorFromMatrix(const std::vector &tensorDims, + const cv::Mat &matrix, cv::Scalar mean, + cv::Scalar variance); cv::Mat getMatrixFromTensor(cv::Size size, const Tensor &tensor); +TensorPtr getTensorFromMatrixGray(const std::vector &tensorDims, + const cv::Mat &matrix); +std::vector grayMatToVector(const cv::Mat &mat); +/** + * @brief Resizes an image to fit within target dimensions while preserving + * aspect ratio, adding padding if needed. Padding color is derived from the + * image's corner pixels for seamless blending. + */ +cv::Mat resizePadded(const cv::Mat inputImage, cv::Size targetSize); /// @brief Read image, resize it and copy it to an ET tensor to store it. +/// @param path Path to the image to be resized. Could be base64, local file or +/// remote URL +/// @param tensorDims The dimensions of the result tensor. The two last +/// dimensions are taken as the image resolution. +/// @param maintainAspectRatio If set to true the image will be resized to +/// maintain the original aspect ratio. The rest of the tensor will be filled +/// padding. /// @return Returns a tensor pointer and the original size of the image. 
std::pair readImageToTensor(const std::string &path, - const std::vector &tensorDims); - -} // namespace rnexecutorch::imageprocessing \ No newline at end of file + const std::vector &tensorDims, + bool maintainAspectRatio = false); +} // namespace rnexecutorch::imageprocessing diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.cpp b/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.cpp index 01a3025779..6e3d9739c5 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.cpp @@ -25,8 +25,7 @@ void normalize(std::span span) { sum += val * val; } - // Early return if all values are 0 - if (sum == 0.0f) { + if (isClose(sum, 0.0f)) { return; } @@ -73,4 +72,11 @@ std::vector meanPooling(std::span modelOutput, return result; } +template bool isClose(T a, T b, T atol) { + return std::abs(a - b) <= atol; +} + +template bool isClose(float, float, float); +template bool isClose(double, double, double); + } // namespace rnexecutorch::numerical \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.h b/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.h index b83495f741..77a13f44fa 100644 --- a/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.h +++ b/packages/react-native-executorch/common/rnexecutorch/data_processing/Numerical.h @@ -10,4 +10,14 @@ void normalize(std::vector &v); void normalize(std::span span); std::vector meanPooling(std::span modelOutput, std::span attnMask); +/** + * @brief Checks if two floating-point numbers are considered equal. 
+ */ +template +bool isClose(T a, T b, + T atol = std::numeric_limits::epsilon() * static_cast(10)); + +extern template bool isClose(float, float, float); +extern template bool isClose(double, double, double); + } // namespace rnexecutorch::numerical \ No newline at end of file diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h index 78d037f06f..defb24e3f9 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h @@ -16,6 +16,7 @@ #include #include #include +#include namespace rnexecutorch::jsiconversion { @@ -95,7 +96,6 @@ inline JSTensorViewIn getValue(const jsi::Value &val, if (dataObj.isArrayBuffer(runtime)) { jsi::ArrayBuffer arrayBuffer = dataObj.getArrayBuffer(runtime); tensorView.dataPtr = arrayBuffer.data(runtime); - } else { // Handle typed arrays (Float32Array, Int32Array, etc.) 
const bool isValidTypedArray = dataObj.hasProperty(runtime, "buffer") && @@ -386,4 +386,32 @@ inline jsi::Value getJsiValue(const std::vector &detections, return array; } +inline jsi::Value getJsiValue(const std::vector &detections, + jsi::Runtime &runtime) { + auto jsiDetections = jsi::Array(runtime, detections.size()); + for (size_t i = 0; i < detections.size(); ++i) { + const auto &detection = detections[i]; + + auto jsiDetectionObject = jsi::Object(runtime); + + auto jsiBboxArray = jsi::Array(runtime, 4); +#pragma unroll + for (size_t j = 0; j < 4u; ++j) { + auto jsiPointObject = jsi::Object(runtime); + jsiPointObject.setProperty(runtime, "x", detection.bbox[j].x); + jsiPointObject.setProperty(runtime, "y", detection.bbox[j].y); + jsiBboxArray.setValueAtIndex(runtime, j, jsiPointObject); + } + + jsiDetectionObject.setProperty(runtime, "bbox", jsiBboxArray); + jsiDetectionObject.setProperty( + runtime, "text", jsi::String::createFromUtf8(runtime, detection.text)); + jsiDetectionObject.setProperty(runtime, "score", detection.score); + + jsiDetections.setValueAtIndex(runtime, i, jsiDetectionObject); + } + + return jsiDetections; +} + } // namespace rnexecutorch::jsiconversion diff --git a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h index d4fb4205e0..eb4e426149 100644 --- a/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h +++ b/packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h @@ -16,6 +16,8 @@ #include #include #include +#include +#include namespace rnexecutorch { @@ -90,6 +92,16 @@ template class ModelHostObject : public JsiHostObject { addFunctions( JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); } + + if constexpr (meta::SameAs) { + addFunctions( + JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); + } + + if constexpr (meta::SameAs) { + addFunctions( + 
JSI_EXPORT_FUNCTION(ModelHostObject, unload, "unload")); + } } // A generic host function that runs synchronously, works analogously to the diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp index d734fca971..4edd483134 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp +++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp @@ -29,7 +29,7 @@ BaseModel::BaseModel(const std::string &modelSource, } std::vector BaseModel::getInputShape(std::string method_name, - int index) { + int32_t index) { if (!module_) { throw std::runtime_error("Model not loaded: Cannot get input shape"); } @@ -166,9 +166,11 @@ BaseModel::execute(const std::string &methodName, return module_->execute(methodName, input_value); } -std::size_t BaseModel::getMemoryLowerBound() { return memorySizeLowerBound; } +std::size_t BaseModel::getMemoryLowerBound() const noexcept { + return memorySizeLowerBound; +} -void BaseModel::unload() { module_.reset(nullptr); } +void BaseModel::unload() noexcept { module_.reset(nullptr); } std::vector BaseModel::getTensorShape(const executorch::aten::Tensor &tensor) { diff --git a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h index 3f8b106c38..a6463b3ef9 100644 --- a/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h +++ b/packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h @@ -17,9 +17,9 @@ class BaseModel { public: BaseModel(const std::string &modelSource, std::shared_ptr callInvoker); - std::size_t getMemoryLowerBound(); - void unload(); - std::vector getInputShape(std::string method_name, int index); + std::size_t getMemoryLowerBound() const noexcept; + void unload() noexcept; + std::vector getInputShape(std::string method_name, int32_t index); std::vector> 
getAllInputShapes(std::string methodName = "forward"); std::vector diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/CTCLabelConverter.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/CTCLabelConverter.cpp new file mode 100644 index 0000000000..4e78ddabc7 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/CTCLabelConverter.cpp @@ -0,0 +1,88 @@ +#include "CTCLabelConverter.h" +#include +#include + +namespace rnexecutorch::ocr { +CTCLabelConverter::CTCLabelConverter(const std::string &characters) + : ignoreIdx(0), + character({"[blank]"}) // blank character is ignored character (index 0). +{ + for (size_t i = 0; i < characters.length();) { + size_t char_len = 0; + unsigned char first_byte = characters[i]; + + if ((first_byte & 0x80) == 0) { // 0xxxxxxx -> 1-byte character + char_len = 1; + } else if ((first_byte & 0xE0) == 0xC0) { // 110xxxxx -> 2-byte character + char_len = 2; + } else if ((first_byte & 0xF0) == 0xE0) { // 1110xxxx -> 3-byte character + char_len = 3; + } else if ((first_byte & 0xF8) == 0xF0) { // 11110xxx -> 4-byte character + char_len = 4; + } else { + // Invalid UTF-8 start byte, treat as a single byte character to avoid + // infinite loop + char_len = 1; + } + + // Ensure we don't read past the end of the string + if (i + char_len <= characters.length()) { + character.push_back(characters.substr(i, char_len)); + } + i += char_len; + } +} + +std::vector +CTCLabelConverter::decodeGreedy(const std::vector &textIndex, + size_t length) { + /* + The current strategy used for decoding is greedy approach + which iterates through the list of indices and process + each index using following steps: + 1. Ignore if idx == 0 + 2. Ignore if idx is the same as last idx + 3. decode idx -> char and append it to returned text. 
+ + Note that ignoring repeated indices, does not mean decoding + won't handle repeated letters in a word, since in most cases + actual chars are already separated by blank tokens. + */ + std::vector texts; + size_t index = 0; + + while (index < textIndex.size()) { + size_t segmentLength = std::min(length, textIndex.size() - index); + + std::vector subArray(textIndex.begin() + index, + textIndex.begin() + index + segmentLength); + + std::string text; + + if (!subArray.empty()) { + std::optional lastChar; + for (int32_t currentChar : subArray) { + bool isRepeated = + lastChar.has_value() && lastChar.value() == currentChar; + bool isIgnored = currentChar == ignoreIdx; + lastChar = currentChar; + + if (currentChar >= 0 && + currentChar < static_cast(character.size()) && + !isRepeated && !isIgnored) { + text += character[currentChar]; + } + } + } + + texts.push_back(std::move(text)); + index += segmentLength; + + if (segmentLength < length) { + break; + } + } + + return texts; +} +} // namespace rnexecutorch::ocr diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/CTCLabelConverter.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/CTCLabelConverter.h new file mode 100644 index 0000000000..06eabea403 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/CTCLabelConverter.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include + +namespace rnexecutorch::ocr { +/* + CTC (Connectionist Temporal Classification) Label Converter + is used for decoding the returned list of indices by Recognizer into + actual characters. + For each Recognizer there is a 1:1 correspondence between + an index and a character. CTC Label Converter operates on this + mapping. Symbol corresponding to the first index is a [blank] + character, meaning "no character to decode here". + The decoder ignores [blank] char. 
+*/ + +class CTCLabelConverter final { +public: + explicit CTCLabelConverter(const std::string &characters); + + std::vector decodeGreedy(const std::vector &textIndex, + size_t length); + +private: + std::vector character; + int32_t ignoreIdx; +}; +} // namespace rnexecutorch::ocr diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/Constants.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Constants.h new file mode 100644 index 0000000000..d1934058b3 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Constants.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +namespace rnexecutorch::ocr { + +inline constexpr float textThreshold = 0.4; +inline constexpr float textThresholdVertical = 0.3; +inline constexpr float linkThreshold = 0.4; +inline constexpr float lowTextThreshold = 0.7; +inline constexpr float centerThreshold = 0.5; +inline constexpr float distanceThreshold = 2.0; +inline constexpr float heightThreshold = 2.0; +inline constexpr float singleCharacterCenterThreshold = 0.3; +inline constexpr float lowConfidenceThreshold = 0.3; +inline constexpr float adjustContrast = 0.2; +inline constexpr int32_t minSideThreshold = 15; +inline constexpr int32_t maxSideThreshold = 30; +inline constexpr int32_t recognizerHeight = 64; +inline constexpr int32_t largeRecognizerWidth = 512; +inline constexpr int32_t mediumRecognizerWidth = 256; +inline constexpr int32_t smallRecognizerWidth = 128; +inline constexpr int32_t smallVerticalRecognizerWidth = 64; +inline constexpr int32_t maxWidth = + largeRecognizerWidth + (largeRecognizerWidth * 0.15); +inline constexpr int32_t minSize = 20; +inline constexpr int32_t singleCharacterMinSize = 70; +inline constexpr int32_t recognizerImageSize = 1280; +inline constexpr int32_t verticalLineThreshold = 20; + +inline const cv::Scalar mean(0.485, 0.456, 0.406); +inline const cv::Scalar variance(0.229, 0.224, 0.225); + +} // namespace rnexecutorch::ocr diff 
--git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.cpp new file mode 100644 index 0000000000..a68d29b9c7 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.cpp @@ -0,0 +1,102 @@ +#include "Detector.h" +#include +#include +#include + +namespace rnexecutorch { +Detector::Detector(const std::string &modelSource, + std::shared_ptr callInvoker) + : BaseModel(modelSource, callInvoker) { + auto inputShapes = getAllInputShapes(); + if (inputShapes.empty()) { + throw std::runtime_error( + "Detector model seems to not take any input tensors."); + } + std::vector modelInputShape = inputShapes[0]; + if (modelInputShape.size() < 2) { + throw std::runtime_error("Unexpected detector model input size, expected " + "at least 2 dimensions but got: " + + std::to_string(modelInputShape.size()) + "."); + } + modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], + modelInputShape[modelInputShape.size() - 2]); +} + +cv::Size Detector::getModelImageSize() const noexcept { return modelImageSize; } + +std::vector Detector::generate(const cv::Mat &inputImage) { + /* + Detector as an input accepts tensor with a shape of [1, 3, H, H]. + where H is a constant for model. In our supported models it is currently + either H=800 or H=1280. + Due to big influence of resize to quality of recognition the image preserves + original aspect ratio and the missing parts are filled with padding. 
+ */ + auto inputShapes = getAllInputShapes(); + cv::Mat resizedInputImage = + imageprocessing::resizePadded(inputImage, getModelImageSize()); + TensorPtr inputTensor = imageprocessing::getTensorFromMatrix( + inputShapes[0], resizedInputImage, ocr::mean, ocr::variance); + auto forwardResult = BaseModel::forward(inputTensor); + if (!forwardResult.ok()) { + throw std::runtime_error( + "Failed to forward, error: " + + std::to_string(static_cast(forwardResult.error()))); + } + + return postprocess(forwardResult->at(0).toTensor()); +} + +std::vector +Detector::postprocess(const Tensor &tensor) const { + /* + The output of the model consists of two matrices (heat maps): + 1. ScoreText(Score map) - The probability of a region containing character. + 2. ScoreAffinity(Affinity map) - affinity between characters, used to + group each character into a single instance (sequence) Both matrices are + H/2xW/2 (400x400 or 640x640). + */ + std::span tensorData(tensor.const_data_ptr(), + tensor.numel()); + /* + The output of the model is a matrix half the size of the input image + containing two channels representing the heatmaps. + */ + auto [scoreTextMat, scoreAffinityMat] = ocr::interleavedArrayToMats( + tensorData, + cv::Size(modelImageSize.width / 2, modelImageSize.height / 2)); + + /* + Heatmaps are then converted into list of bounding boxes. + To see how it is achieved see the description of this function in + the DetectorUtils.h source file and the implementation in the + DetectorUtils.cpp. + */ + std::vector bBoxesList = ocr::getDetBoxesFromTextMap( + scoreTextMat, scoreAffinityMat, ocr::textThreshold, ocr::linkThreshold, + ocr::lowTextThreshold); + + /* + Bounding boxes are at first corresponding to the 400x400 size or 640x640. + RecognitionHandler in the later part of processing works on images of size + 1280x1280. To match this difference we have to scale by the proper factor + (3.2 or 2.0). 
+ */ + const float restoreRatio = + ocr::calculateRestoreRatio(scoreTextMat.rows, ocr::recognizerImageSize); + ocr::restoreBboxRatio(bBoxesList, restoreRatio); + /* + Since every bounding box is processed separately by Recognition models, we'd + like to reduce the number of boxes. Also, grouping nearby boxes means we + process many words / full line at once. It is not only faster but also easier + for Recognizer models than recognition of single characters. + */ + bBoxesList = ocr::groupTextBoxes(bBoxesList, ocr::centerThreshold, + ocr::distanceThreshold, ocr::heightThreshold, + ocr::minSideThreshold, ocr::maxSideThreshold, + ocr::maxWidth); + + return bBoxesList; +} + +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.h new file mode 100644 index 0000000000..b1176dd1d7 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Detector.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include +#include +#include + +namespace rnexecutorch { +/* + Detector is a model responsible for recognizing the areas where text is + located. It returns the list of bounding boxes. The model used as detector is + based on CRAFT (Character Region Awareness for Text Detection) paper. 
+ https://arxiv.org/pdf/1904.01941 +*/ + +using executorch::aten::Tensor; +using executorch::extension::TensorPtr; + +class Detector final : public BaseModel { +public: + explicit Detector(const std::string &modelSource, + std::shared_ptr callInvoker); + std::vector generate(const cv::Mat &inputImage); + cv::Size getModelImageSize() const noexcept; + +private: + std::vector postprocess(const Tensor &tensor) const; + cv::Size modelImageSize; +}; +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/DetectorUtils.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/DetectorUtils.cpp new file mode 100644 index 0000000000..a77647ea11 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/DetectorUtils.cpp @@ -0,0 +1,703 @@ +#include "DetectorUtils.h" +#include +#include +#include +#include +#include +#include +#include + +namespace rnexecutorch::ocr { +std::array +cvPointsFromPoints(const std::array &points) { + std::array cvPoints; +#pragma unroll + for (std::size_t i = 0; i < cvPoints.size(); ++i) { + cvPoints[i] = cv::Point2f(points[i].x, points[i].y); + } + return cvPoints; +} + +std::array pointsFromCvPoints(cv::Point2f cvPoints[4]) { + std::array points; +#pragma unroll + for (std::size_t i = 0; i < points.size(); ++i) { + points[i] = {.x = cvPoints[i].x, .y = cvPoints[i].y}; + } + return points; +} + +std::pair interleavedArrayToMats(std::span data, + cv::Size size) { + cv::Mat mat1 = cv::Mat(size, CV_32F); + cv::Mat mat2 = cv::Mat(size, CV_32F); + + for (std::size_t i = 0; i < data.size(); i++) { + const float value = data[i]; + const int32_t x = (i / 2) % size.width; + const int32_t y = (i / 2) / size.width; + + if (i % 2 == 0) { + mat1.at(y, x) = value; + } else { + mat2.at(y, x) = value; + } + } + return {mat1, mat2}; +} + +// Create a segmentation map for the current component. 
+// Background is 0, (black), foreground is 255 (white) +cv::Mat createSegmentMap(const cv::Mat &mask, cv::Size mapSize, + const int32_t segmentColor = 255) { + cv::Mat segMap = cv::Mat::zeros(mapSize, CV_8U); + segMap.setTo(segmentColor, mask); + return segMap; +} + +void morphologicalOperations( + const cv::Mat &segMap, const cv::Mat &stats, int32_t i, int32_t area, + int32_t imgW, int32_t imgH, + int32_t iterations = 1, // iterations number of times dilation is applied. + cv::Size anchor = + cv::Point(-1, -1) // anchor position of the anchor within the element; + // default means that the anchor is at the center. +) { + const int32_t x = stats.at(i, cv::CC_STAT_LEFT); + const int32_t y = stats.at(i, cv::CC_STAT_TOP); + const int32_t w = stats.at(i, cv::CC_STAT_WIDTH); + const int32_t h = stats.at(i, cv::CC_STAT_HEIGHT); + + // Dynamically calculate dilation radius to expand the bounding box slightly + constexpr int32_t evenMultiplyCoeff = 2; // ensure that dilationRadius is even + const int32_t dilationRadius = static_cast( + std::sqrt(static_cast(area) / std::max(w, h)) * + evenMultiplyCoeff); + const int32_t sx = std::max(x - dilationRadius, 0); + const int32_t ex = std::min(x + w + dilationRadius, imgW); + const int32_t sy = std::max(y - dilationRadius, 0); + const int32_t ey = std::min(y + h + dilationRadius, imgH); + + // Define a region of interest (ROI) and dilate it + cv::Rect roi(sx, sy, ex - sx, ey - sy); + // Morphological kernels require minimum size of 1x1 (no-op) plus dilation + // radius + const int32_t morphologicalKernelSize = + 1 + dilationRadius; // Ensures valid odd-sized kernel, + // notice the fact that dilationRadius is always even. 
+ cv::Mat kernel = cv::getStructuringElement( + cv::MORPH_RECT, + cv::Size(morphologicalKernelSize, morphologicalKernelSize)); + cv::Mat roiSegMap = segMap(roi); + cv::dilate(roiSegMap, roiSegMap, kernel, anchor, iterations); +} + +DetectorBBox +extractMinAreaBBoxFromContour(const std::vector contour) { + cv::RotatedRect minRect = cv::minAreaRect(contour); + + std::array vertices; + minRect.points(vertices.data()); + + std::array points = pointsFromCvPoints(vertices.data()); + return {.bbox = points, .angle = minRect.angle}; +} + +void getBoxFromContour(cv::Mat &segMap, + std::vector &detectedBoxes) { + std::vector> contours; + cv::findContours(segMap, contours, cv::RETR_EXTERNAL, + cv::CHAIN_APPROX_SIMPLE); + if (!contours.empty()) { + detectedBoxes.emplace_back(extractMinAreaBBoxFromContour(contours[0])); + } +} + +// Function for processing single component. It is shared between the +// VerticalOCR and standard OCR. param isVertical specifies which OCR uses it. +// param lowTextThreshold is used only by standard OCR. +void processComponent(const cv::Mat &textMap, const cv::Mat &labels, + const cv::Mat &stats, int32_t i, int32_t imgW, + int32_t imgH, std::vector &detectedBoxes, + bool isVertical, int32_t minimalAreaThreshold, + int32_t dilationIter, float lowTextThreshold = 0.0) { + const int32_t area = stats.at(i, cv::CC_STAT_AREA); + // Skip small components as they are likely to be just noise + if (area < minimalAreaThreshold) { + return; + } + + cv::Mat mask = (labels == i); + + if (!isVertical) { + // Skip components with low values, as they are likely to be just noise + double maxVal; + cv::minMaxLoc(textMap, nullptr, &maxVal, nullptr, nullptr, mask); + if (maxVal < lowTextThreshold) { + return; + } + } + + cv::Mat segMap = createSegmentMap(mask, textMap.size()); + + // Perform morphological operations on the segment map. 
+ // mostly includes the dilation of the region of interest + // to ensure the box captures the whole area + morphologicalOperations(segMap, stats, i, area, imgW, imgH, dilationIter); + + // Find the minimum area rotated rectangle around the contour + // and add it to the box list. + getBoxFromContour(segMap, detectedBoxes); +} + +std::vector getDetBoxesFromTextMap(cv::Mat &textMap, + cv::Mat &affinityMap, + float textThreshold, + float linkThreshold, + float lowTextThreshold) { + // Ensure input mats are of the correct type for processing + CV_Assert(textMap.type() == CV_32F && affinityMap.type() == CV_32F); + + const int32_t imgH = textMap.rows; + const int32_t imgW = textMap.cols; + cv::Mat textScore; + cv::Mat affinityScore; + + // 1. Based on maps and threshold values create binary masks + constexpr double maxValBinaryMask = 1.0; + cv::threshold(textMap, textScore, textThreshold, maxValBinaryMask, + cv::THRESH_BINARY); + cv::threshold(affinityMap, affinityScore, linkThreshold, maxValBinaryMask, + cv::THRESH_BINARY); + + // 2. Merge two maps into one using logical OR + cv::Mat textScoreComb = textScore + affinityScore; + constexpr double threshVal = 0.0; + cv::threshold(textScoreComb, textScoreComb, threshVal, maxValBinaryMask, + cv::THRESH_BINARY); + cv::Mat binaryMat; + textScoreComb.convertTo(binaryMat, CV_8UC1); + + // 3. Find connected components to identify each box + cv::Mat labels, stats, centroids; + constexpr int32_t connectivityType = 4; + const int32_t nLabels = cv::connectedComponentsWithStats( + binaryMat, labels, stats, centroids, connectivityType); + + std::vector detectedBoxes; + detectedBoxes.reserve(nLabels); // Pre-allocate memory + + // number of dilation iterations performed in some + // morphological operations on a component later on. + constexpr int32_t dilationIter = 1; + // minimal accepted area of component + constexpr int32_t minimalAreaThreshold = 10; + + // 4. 
Process each component; omit component 0 as it is background + for (int32_t i = 1; i < nLabels; i++) { + processComponent(textMap, labels, stats, i, imgW, imgH, detectedBoxes, + false, minimalAreaThreshold, dilationIter, + lowTextThreshold); + } + + return detectedBoxes; +} + +std::vector +getDetBoxesFromTextMapVertical(cv::Mat &textMap, cv::Mat &affinityMap, + float textThreshold, float linkThreshold, + bool independentCharacters) { + // Ensure input mats are of the correct type for processing + CV_Assert(textMap.type() == CV_32F && affinityMap.type() == CV_32F); + + const int32_t imgH = textMap.rows; + const int32_t imgW = textMap.cols; + cv::Mat textScore; + cv::Mat affinityScore; + + // 1. Threshold text and affinity maps to create binary masks + constexpr double maxValBinaryMask = 1.0; + cv::threshold(textMap, textScore, textThreshold, maxValBinaryMask, + cv::THRESH_BINARY); + cv::threshold(affinityMap, affinityScore, linkThreshold, maxValBinaryMask, + cv::THRESH_BINARY); + + // Prepare values for morphological operations + const auto kSize = cv::Size(3, 3); // size of the structuring element + cv::Mat kernel = cv::getStructuringElement(cv::MORPH_RECT, kSize); + + // iterations number of times erosion is applied. + constexpr int32_t erosionIterations = 1; + + // iterations number of times dilation is applied. + int32_t dilationIterations; + const auto anchor = + cv::Point(-1, -1); // anchor position of the anchor within the element; + // default value (-1, -1) + // means that the anchor is at the element center + + // 2. Combine maps based on whether we are detecting words or single + // characters + // For single characters, subtract affinity to separate adjacent chars, + // otherwise add affinity to link characters together + cv::Mat textScoreComb = independentCharacters ? 
textScore - affinityScore + : textScore + affinityScore; + // Clamp values to be >= 0 + cv::threshold(textScoreComb, textScoreComb, 0.0, 1.0, cv::THRESH_TOZERO); + // Clamp values to be <= 1 + cv::threshold(textScoreComb, textScoreComb, 1.0, 1.0, cv::THRESH_TRUNC); + + // Perform morphological operations to refine character regions + if (independentCharacters) { + dilationIterations = 4; + cv::erode(textScoreComb, textScoreComb, kernel, anchor, erosionIterations); + } else { + dilationIterations = 2; + } + cv::dilate(textScoreComb, textScoreComb, kernel, anchor, dilationIterations); + + // 3. Find connected components to identify each character/word + cv::Mat binaryMat; + textScoreComb.convertTo(binaryMat, CV_8UC1); + + cv::Mat labels, stats, centroids; + constexpr int32_t connectivityType = 4; + const int32_t nLabels = cv::connectedComponentsWithStats( + binaryMat, labels, stats, centroids, connectivityType); + + std::vector detectedBoxes; + detectedBoxes.reserve(nLabels); + + // number of dilation iterations performed in some + // morphological operations on a component later on. + constexpr int32_t dilationIter = 2; + // minimal accepted area of component + constexpr int32_t minimalAreaThreshold = 20; + + // 4. 
Process each component; omit component 0 as it is background + for (int32_t i = 1; i < nLabels; ++i) { + const int32_t width = stats.at(i, cv::CC_STAT_WIDTH); + const int32_t height = stats.at(i, cv::CC_STAT_HEIGHT); + // For vertical text (not single chars), height should be greater than width + if (!independentCharacters && height < width) { + continue; + } + processComponent(textMap, labels, stats, i, imgW, imgH, detectedBoxes, true, + minimalAreaThreshold, dilationIter); + } + + return detectedBoxes; +} + +float calculateRestoreRatio(int32_t currentSize, int32_t desiredSize) { + return desiredSize / static_cast(currentSize); +} + +void restoreBboxRatio(std::vector &boxes, float restoreRatio) { + for (auto &box : boxes) { + for (auto &point : box.bbox) { + point.x *= restoreRatio; + point.y *= restoreRatio; + } + } +} + +float distanceFromPoint(const Point &p1, const Point &p2) { + const float xDist = p2.x - p1.x; + const float yDist = p2.y - p1.y; + return std::hypot(xDist, yDist); +} + +float normalizeAngle(float angle) { + return (angle > 45.0f) ? 
(angle - 90.0f) : angle; +} + +Point midpointBetweenPoint(const Point &p1, const Point &p2) { + return {.x = std::midpoint(p1.x, p2.x), .y = std::midpoint(p1.y, p2.y)}; +} + +Point centerOfBox(const std::array &box) { + return midpointBetweenPoint(box[0], box[2]); +} + +// function for both; finding maximal side length and minimal side length +template +float findExtremeSideLength(const std::array &points, Compare comp) { + float extremeLength = distanceFromPoint(points[0], points[1]); + +#pragma unroll + for (std::size_t i = 1; i < points.size(); i++) { + const auto ¤tPoint = points[i]; + const auto &nextPoint = points[(i + 1) % points.size()]; + const float sideLength = distanceFromPoint(currentPoint, nextPoint); + + if (comp(sideLength, extremeLength)) { + extremeLength = sideLength; + } + } + + return extremeLength; +} + +float minSideLength(const std::array &points) { + return findExtremeSideLength(points, std::less{}); +} + +float maxSideLength(const std::array &points) { + return findExtremeSideLength(points, std::greater{}); +} + +/** + * This method calculates the distances between each sequential pair of points + * in a presumed quadrilateral, identifies the two shortest sides, and fits a + * linear model to the midpoints of these sides. It also evaluates whether the + * resulting line should be considered vertical based on a predefined threshold + * for the x-coordinate differences. + * + * If the line is vertical it is fitted as a function of x = my + c, otherwise + * as y = mx + c. + * + * @return A tuple with 2 floats and a bool, where: + * - the first float represents the slope (m) of the line. + * - the second float represents the line's intercept (c) with y-axis. + * - a bool indicating whether the line is + * considered vertical. 
+ */ +std::tuple +fitLineToShortestSides(const std::array &points) { + std::array, 4> sides; + std::array midpoints; +#pragma unroll + for (std::size_t i = 0; i < midpoints.size(); i++) { + const auto p1 = points[i]; + const auto p2 = points[(i + 1) % midpoints.size()]; + + const float sideLength = distanceFromPoint(p1, p2); + sides[i] = std::make_pair(sideLength, i); + midpoints[i] = midpointBetweenPoint(p1, p2); + } + + // Sort the sides by length ascending + std::ranges::sort(sides); + + const Point midpoint1 = midpoints[sides[0].second]; + const Point midpoint2 = midpoints[sides[1].second]; + const float dx = std::fabs(midpoint2.x - midpoint1.x); + + float m, c; + bool isVertical; + + std::array cvMidPoints = { + cv::Point2f(midpoint1.x, midpoint1.y), + cv::Point2f(midpoint2.x, midpoint2.y)}; + cv::Vec4f line; + // parameters for fitLine calculation: + constexpr int32_t numericalParameter = + 0; // important only for some types of distances, 0 means an optimal value + // is chosen + constexpr double accuracy = + 0.01; // sufficient accuracy. 
Value proposed by OPENCV + + isVertical = dx < verticalLineThreshold; + if (isVertical) { + for (auto &pt : cvMidPoints) { + std::swap(pt.x, pt.y); + } + } + cv::fitLine(cvMidPoints, line, cv::DIST_L2, numericalParameter, accuracy, + accuracy); + m = line[1] / line[0]; + c = line[3] - m * line[2]; + return {m, c, isVertical}; +} + +std::array rotateBox(const std::array &box, float angle) { + const Point center = centerOfBox(box); + + const float radians = angle * M_PI / 180.0f; + + std::array rotatedPoints; + for (std::size_t i = 0; i < box.size(); ++i) { + const Point &point = box[i]; + const float translatedX = point.x - center.x; + const float translatedY = point.y - center.y; + + const float rotatedX = + translatedX * std::cos(radians) - translatedY * std::sin(radians); + const float rotatedY = + translatedX * std::sin(radians) + translatedY * std::cos(radians); + + rotatedPoints[i] = {.x = rotatedX + center.x, .y = rotatedY + center.y}; + } + + return rotatedPoints; +} + +float calculateMinimalDistanceBetweenBox(const std::array &box1, + const std::array &box2) { + float minDistance = std::numeric_limits::max(); + for (const Point &corner1 : box1) { + for (const Point &corner2 : box2) { + const float distance = distanceFromPoint(corner1, corner2); + minDistance = std::min(distance, minDistance); + } + } + return minDistance; +} + +/** + * Orders a set of 4 points in a clockwise direction starting with the top-left + * point. + * + * Process: + * 1. It iterates through each Point. + * 2. For each point, it calculates the sum (x + y) and difference (y - x) of + * the coordinates. + * 3. Points are classified into: + * - Top-left: Minimum sum. + * - Bottom-right: Maximum sum. + * - Top-right: Minimum difference. + * - Bottom-left: Maximum difference. + * 4. The points are ordered starting from the top-left in a clockwise manner: + * top-left, top-right, bottom-right, bottom-left. 
+ */ +std::array orderPointsClockwise(const std::array &points) { + Point topLeft, topRight, bottomRight, bottomLeft; + float minSum = std::numeric_limits::max(); + float maxSum = std::numeric_limits::lowest(); + float minDiff = std::numeric_limits::max(); + float maxDiff = std::numeric_limits::lowest(); + + for (const auto &pt : points) { + const float sum = pt.x + pt.y; + const float diff = pt.y - pt.x; + + if (sum < minSum) { + minSum = sum; + topLeft = pt; + } + if (sum > maxSum) { + maxSum = sum; + bottomRight = pt; + } + if (diff < minDiff) { + minDiff = diff; + topRight = pt; + } + if (diff > maxDiff) { + maxDiff = diff; + bottomLeft = pt; + } + } + + return {topLeft, topRight, bottomRight, bottomLeft}; +} + +std::array mergeRotatedBoxes(std::array &box1, + std::array &box2) { + box1 = orderPointsClockwise(box1); + box2 = orderPointsClockwise(box2); + + auto points1 = cvPointsFromPoints(box1); + auto points2 = cvPointsFromPoints(box2); + + std::array allPoints; + std::copy(points1.begin(), points1.end(), allPoints.begin()); + std::copy(points2.begin(), points2.end(), allPoints.begin() + points1.size()); + + std::vector hullIndices; + cv::convexHull(allPoints, hullIndices, false); + + std::vector hullPoints; + for (int32_t idx : hullIndices) { + hullPoints.push_back(allPoints[idx]); + } + + cv::RotatedRect minAreaRect = cv::minAreaRect(hullPoints); + + std::array rectPoints; + minAreaRect.points(rectPoints.data()); + + return pointsFromCvPoints(rectPoints.data()); +} + +/** + * This method assesses each box from a provided vector, checks its center + * against the center of a "current box", and evaluates its alignment with a + * specified line equation. The function specifically searches for the box whose + * center is closest to the current box that has not been ignored, and fits + * within a defined distance from the line. + * + * @param boxes A vector of DetectorBBoxes + * @param ignoredIdxs A set of indices of boxes to ignore in the evaluation. 
+ * @param currentBox Array of points encapsulating representing the current box + * to compare against. + * @param isVertical A boolean indicating if the line to compare distance to is + * vertical. + * @param m The slope (gradient) of the line against which the box's alignment + * is checked. + * @param c The y-intercept of the line equation y = mx + c. + * @param centerThreshold A multiplier to determine the threshold for the + * distance between the box's center and the line. + * + * @return A an optional pair containing: + * - the index of the found box in the original vector. + * - the length of the shortest side of the found box. + * If no suitable box is found the optional is null. + */ +std::optional> +findClosestBox(const std::vector &boxes, + const std::unordered_set &ignoredIdxs, + const std::array ¤tBox, bool isVertical, float m, + float c, float centerThreshold) { + float smallestDistance = std::numeric_limits::max(); + ssize_t idx = -1; + float boxHeight = 0.0f; + const Point centerOfCurrentBox = centerOfBox(currentBox); + + for (std::size_t i = 0; i < boxes.size(); i++) { + if (ignoredIdxs.contains(i)) { + continue; + } + std::array bbox = boxes[i].bbox; + const Point centerOfProcessedBox = centerOfBox(bbox); + const float distanceBetweenCenters = + distanceFromPoint(centerOfCurrentBox, centerOfProcessedBox); + + if (distanceBetweenCenters >= smallestDistance) { + continue; + } + + boxHeight = minSideLength(bbox); + + const float lineDistance = + isVertical ? std::fabs(centerOfProcessedBox.x - + (m * centerOfProcessedBox.y + c)) + : std::fabs(centerOfProcessedBox.y - + (m * centerOfProcessedBox.x + c)); + + if (lineDistance < boxHeight * centerThreshold) { + idx = i; + smallestDistance = distanceBetweenCenters; + } + } + + return idx != -1 ? std::optional(std::make_pair(idx, boxHeight)) + : std::nullopt; +} + +/** + * Filters out boxes that are smaller than the specified thresholds. 
+ * A box is kept only if: + * - Its shorter side is **greater than** `minSideThreshold`, **and** + * - Its longer side is **greater than** `maxSideThreshold`. + * Otherwise, the box is excluded from the result. + */ +std::vector +removeSmallBoxesFromArray(const std::vector &boxes, + float minSideThreshold, float maxSideThreshold) { + std::vector filteredBoxes; + + for (const auto &box : boxes) { + const float maxSide = maxSideLength(box.bbox); + const float minSide = minSideLength(box.bbox); + if (minSide > minSideThreshold && maxSide > maxSideThreshold) { + filteredBoxes.push_back(box); + } + } + + return filteredBoxes; +} + +static float minimumYFromBox(const std::array &box) { + return std::ranges::min_element(box, + [](Point a, Point b) { return a.y < b.y; }) + ->y; +} + +std::vector +groupTextBoxes(std::vector &boxes, float centerThreshold, + float distanceThreshold, float heightThreshold, + int32_t minSideThreshold, int32_t maxSideThreshold, + int32_t maxWidth) { + // Sort boxes descending by maximum side length + std::ranges::sort(boxes, + [](const DetectorBBox &lhs, const DetectorBBox &rhs) { + return maxSideLength(lhs.bbox) > maxSideLength(rhs.bbox); + }); + + std::vector mergedVec; + float lineAngle; + std::unordered_set ignoredIdxs; + while (!boxes.empty()) { + auto currentBox = boxes[0]; + float normalizedAngle = normalizeAngle(currentBox.angle); + boxes.erase(boxes.begin()); + ignoredIdxs.clear(); + + while (true) { + // Find all aligned boxes and merge them until max_size is reached or no + // more boxes can be merged + auto [slope, intercept, isVertical] = + fitLineToShortestSides(currentBox.bbox); + + lineAngle = std::atan(slope) * 180.0f / M_PI; + if (isVertical) { + lineAngle = -90.0f; + } + auto closestBoxInfo = + findClosestBox(boxes, ignoredIdxs, currentBox.bbox, isVertical, slope, + intercept, centerThreshold); + if (!closestBoxInfo.has_value()) { + break; + } + const auto [candidateIdx, candidateHeight] = closestBoxInfo.value(); + 
DetectorBBox candidateBox = boxes[candidateIdx]; + + if ((numerical::isClose(candidateBox.angle, 90.0f) && !isVertical) || + (numerical::isClose(candidateBox.angle, 0.0f) && isVertical)) { + candidateBox.bbox = rotateBox(candidateBox.bbox, normalizedAngle); + } + + const float minDistance = calculateMinimalDistanceBetweenBox( + candidateBox.bbox, currentBox.bbox); + const float mergedHeight = minSideLength(currentBox.bbox); + if (minDistance < distanceThreshold * candidateHeight && + std::fabs(mergedHeight - candidateHeight) < + candidateHeight * heightThreshold) { + currentBox.bbox = mergeRotatedBoxes(currentBox.bbox, candidateBox.bbox); + boxes.erase(boxes.begin() + candidateIdx); + ignoredIdxs.clear(); + if (maxSideLength(currentBox.bbox) > maxWidth) { + break; + } + } else { + ignoredIdxs.insert(candidateIdx); + } + } + mergedVec.emplace_back(currentBox.bbox, lineAngle); + } + + // Remove small boxes and sort by vertical + mergedVec = + removeSmallBoxesFromArray(mergedVec, minSideThreshold, maxSideThreshold); + + std::ranges::sort(mergedVec, [](const auto &obj1, const auto &obj2) { + const auto &coords1 = obj1.bbox; + const auto &coords2 = obj2.bbox; + const float minY1 = minimumYFromBox(coords1); + const float minY2 = minimumYFromBox(coords2); + return minY1 < minY2; + }); + + std::vector orderedSortedBoxes; + orderedSortedBoxes.reserve(mergedVec.size()); + for (DetectorBBox bbox : mergedVec) { + bbox.bbox = orderPointsClockwise(bbox.bbox); + orderedSortedBoxes.push_back(std::move(bbox)); + } + + return orderedSortedBoxes; +} + +} // namespace rnexecutorch::ocr diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/DetectorUtils.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/DetectorUtils.h new file mode 100644 index 0000000000..7aa81681a6 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/DetectorUtils.h @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include + +namespace 
rnexecutorch::ocr { +std::pair interleavedArrayToMats(std::span data, + cv::Size size); +/** + * This method applies a series of image processing operations to identify + * likely areas of text in the textMap and return the bounding boxes for single + * words. + * + * @param textMap A cv::Mat representing a heat map of the characters of text + * being present in an image. + * @param affinityMap A cv::Mat representing a heat map of the affinity between + * characters. + * @param textThreshold A float representing the threshold for the text map. + * @param linkThreshold A float representing the threshold for the affinity + * map. + * @param lowTextThreshold A float representing the low text. + * + * @return A vector containing DetectorBBox bounding boxes. Each DetectorBBox + * includes: + * - "bbox": an array of Point values representing the vertices of the + * detected text box. + * - "angle": a float representing the rotation angle of the box. + */ +std::vector getDetBoxesFromTextMap(cv::Mat &textMap, + cv::Mat &affinityMap, + float textThreshold, + float linkThreshold, + float lowTextThreshold); +std::vector +getDetBoxesFromTextMapVertical(cv::Mat &textMap, cv::Mat &affinityMap, + float textThreshold, float linkThreshold, + bool independentCharacters); + +float calculateRestoreRatio(int32_t currentSize, int32_t desiredSize); + +void restoreBboxRatio(std::vector &boxes, float restoreRatio); +/** + * This method processes a vector of DetectorBBox bounding boxes, each + * containing details about individual text boxes, and attempts to group and + * merge these boxes based on specified criteria including proximity, alignment, + * and size thresholds. It prioritizes merging of boxes that are aligned closely + * in angle, are near each other, and whose sizes are compatible based on the + * given thresholds. + * + * @param boxes A vector of DetectorBBoxes where each bounding box + * represents a text box. 
+ * @param centerThreshold A float representing the threshold for considering + * the distance between center and fitted line. + * @param distanceThreshold A float that defines the maximum allowed distance + * between boxes for them to be considered for merging. + * @param heightThreshold A float representing the maximum allowed difference + * in height between boxes for merging. + * @param minSideThreshold An int that defines the minimum dimension threshold + * to filter out small boxes after grouping. + * @param maxSideThreshold An int that specifies the maximum dimension threshold + * for filtering boxes post-grouping. + * @param maxWidth An int that represents the maximum width allowable for a + * merged box. + * + * @return A vector of DetectorBBoxes representing the merged boxes. + * + * Processing Steps: + * 1. Sort initial boxes based on their maximum side length. + * 2. Sequentially merge boxes considering alignment, proximity, and size + * compatibility. + * 3. Post-processing to remove any boxes that are too small. + * 4. Sort the final array of boxes by their vertical positions. 
+ */ +std::vector +groupTextBoxes(std::vector &boxes, float centerThreshold, + float distanceThreshold, float heightThreshold, + int32_t minSideThreshold, int32_t maxSideThreshold, + int32_t maxWidth); +} // namespace rnexecutorch::ocr diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp new file mode 100644 index 0000000000..9d23587539 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.cpp @@ -0,0 +1,52 @@ +#include "OCR.h" +#include +#include + +namespace rnexecutorch { +OCR::OCR(const std::string &detectorSource, + const std::string &recognizerSourceLarge, + const std::string &recognizerSourceMedium, + const std::string &recognizerSourceSmall, std::string symbols, + std::shared_ptr callInvoker) + : detector(detectorSource, callInvoker), + recognitionHandler(recognizerSourceLarge, recognizerSourceMedium, + recognizerSourceSmall, symbols, callInvoker) {} + +std::vector OCR::generate(std::string input) { + cv::Mat image = imageprocessing::readImage(input); + if (image.empty()) { + throw std::runtime_error("Failed to load image from path: " + input); + } + + /* + 1. Detection process returns the list of bounding boxes containing areas + with text. They are corresponding to the image of size 1280x1280, which + is a size later used by Recognition Handler. + */ + std::vector bboxesList = detector.generate(image); + cv::cvtColor(image, image, cv::COLOR_BGR2GRAY); + + /* + Recognition Handler is responsible for deciding which Recognition model to + use for each box. 
It returns the list of tuples; each consisting of: + - recognized text + - coordinates of bounding box corresponding to the original image size + - confidence score + */ + std::vector result = recognitionHandler.recognize( + bboxesList, image, + cv::Size(ocr::recognizerImageSize, ocr::recognizerImageSize)); + + return result; +} + +std::size_t OCR::getMemoryLowerBound() const noexcept { + return detector.getMemoryLowerBound() + + recognitionHandler.getMemoryLowerBound(); +} + +void OCR::unload() noexcept { + detector.unload(); + recognitionHandler.unload(); +} +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h new file mode 100644 index 0000000000..6a9cc90b5a --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/OCR.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace rnexecutorch { +/* + The OCR consists of two phases: + 1. Detection - detecting text regions in the image, the result of this phase + is a list of bounding boxes. + 2. Recognition - recognizing the text in the bounding boxes, the result is a + list of strings and corresponding boxes & confidence scores. + + Recognition uses three models, each model is resposible for recognizing text + of different sizes (e.g. large - 512x64, medium - 256x64, small - 128x64). 
+*/ + +class OCR final { +public: + explicit OCR(const std::string &detectorSource, + const std::string &recognizerSourceLarge, + const std::string &recognizerSourceMedium, + const std::string &recognizerSourceSmall, std::string symbols, + std::shared_ptr callInvoker); + std::vector generate(std::string input); + std::size_t getMemoryLowerBound() const noexcept; + void unload() noexcept; + +private: + Detector detector; + RecognitionHandler recognitionHandler; +}; +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.cpp new file mode 100644 index 0000000000..78cff346dc --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.cpp @@ -0,0 +1,107 @@ +#include "RecognitionHandler.h" +#include "RecognitionHandlerUtils.h" +#include +#include +#include + +namespace rnexecutorch { +RecognitionHandler::RecognitionHandler( + const std::string &recognizerSourceLarge, + const std::string &recognizerSourceMedium, + const std::string &recognizerSourceSmall, std::string symbols, + std::shared_ptr callInvoker) + : converter(symbols), recognizerLarge(recognizerSourceLarge, callInvoker), + recognizerMedium(recognizerSourceMedium, callInvoker), + recognizerSmall(recognizerSourceSmall, callInvoker) { + memorySizeLowerBound = recognizerSmall.getMemoryLowerBound() + + recognizerMedium.getMemoryLowerBound() + + recognizerLarge.getMemoryLowerBound(); +} + +std::pair, float> +RecognitionHandler::runModel(cv::Mat image) { + + // Note that the height of an image is always equal to 64. 
+ if (image.cols >= ocr::largeRecognizerWidth) { + return recognizerLarge.generate(image); + } + if (image.cols >= ocr::mediumRecognizerWidth) { + return recognizerMedium.generate(image); + } + return recognizerSmall.generate(image); +} + +void RecognitionHandler::processBBox(std::vector &boxList, + ocr::DetectorBBox &box, cv::Mat &imgGray, + ocr::PaddingInfo ratioAndPadding) { + + /* + Resize the cropped image to have height = 64 (height accepted by + Recognizer). + */ + auto croppedImage = ocr::cropImage(box, imgGray, ocr::recognizerHeight); + + if (croppedImage.empty()) { + return; + } + + /* + Cropped image is resized into the closest of on of three: + 128x64, 256x64, 512x64. + */ + croppedImage = ocr::normalizeForRecognizer( + croppedImage, ocr::recognizerHeight, ocr::adjustContrast, false); + + auto [predictionIndices, confidenceScore] = this->runModel(croppedImage); + if (confidenceScore < ocr::lowConfidenceThreshold) { + cv::rotate(croppedImage, croppedImage, cv::ROTATE_180); + auto [rotatedPredictionIndices, rotatedConfidenceScore] = + runModel(croppedImage); + if (rotatedConfidenceScore > confidenceScore) { + confidenceScore = rotatedConfidenceScore; + predictionIndices = rotatedPredictionIndices; + } + } + /* + Since the boxes were corresponding to the image resized to 1280x1280, + we want to return the boxes shifted and rescaled to match the original + image dimensions. + */ + for (auto &point : box.bbox) { + point.x = (point.x - ratioAndPadding.left) * ratioAndPadding.resizeRatio; + point.y = (point.y - ratioAndPadding.top) * ratioAndPadding.resizeRatio; + } + boxList.emplace_back( + box.bbox, + converter.decodeGreedy(predictionIndices, predictionIndices.size())[0], + confidenceScore); +} + +std::vector +RecognitionHandler::recognize(std::vector bboxesList, + cv::Mat &imgGray, cv::Size desiredSize) { + /* + Recognition Handler accepts bboxesList corresponding to size + 1280x1280, which is desiredSize. 
+ */ + ocr::PaddingInfo ratioAndPadding = + ocr::calculateResizeRatioAndPaddings(imgGray.size(), desiredSize); + imgGray = imageprocessing::resizePadded(imgGray, desiredSize); + + std::vector result = {}; + for (auto &box : bboxesList) { + processBBox(result, box, imgGray, ratioAndPadding); + } + return result; +} + +std::size_t RecognitionHandler::getMemoryLowerBound() const noexcept { + return memorySizeLowerBound; +} + +void RecognitionHandler::unload() noexcept { + recognizerSmall.unload(); + recognizerMedium.unload(); + recognizerLarge.unload(); +} +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.h new file mode 100644 index 0000000000..a02e028516 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandler.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace rnexecutorch { +/* + Recogntion Handler is responsible for: + 1. Preparing the image to be processed by Recognition Model. + 2. Deciding which Recogntion Model is used for each detected bounding box. + 3. Returning the list of tuples (box, text, confidence) to the OCR class. 
+*/ + +class RecognitionHandler final { +public: + explicit RecognitionHandler(const std::string &recognizerSourceLarge, + const std::string &recognizerSourceMedium, + const std::string &recognizerSourceSmall, + std::string symbols, + std::shared_ptr callInvoker); + std::vector recognize(std::vector bboxesList, + cv::Mat &imgGray, cv::Size desiredSize); + void unload() noexcept; + std::size_t getMemoryLowerBound() const noexcept; + +private: + std::pair, float> runModel(cv::Mat image); + void processBBox(std::vector &boxList, ocr::DetectorBBox &box, + cv::Mat &imgGray, ocr::PaddingInfo ratioAndPadding); + std::size_t memorySizeLowerBound{0}; + ocr::CTCLabelConverter converter; + Recognizer recognizerLarge; + Recognizer recognizerMedium; + Recognizer recognizerSmall; +}; +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandlerUtils.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandlerUtils.cpp new file mode 100644 index 0000000000..3270da4175 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandlerUtils.cpp @@ -0,0 +1,153 @@ +#include "RecognitionHandlerUtils.h" +#include +#include +#include + +namespace rnexecutorch::ocr { +PaddingInfo calculateResizeRatioAndPaddings(cv::Size size, + cv::Size desiredSize) { + const auto newRatioH = static_cast(desiredSize.height) / size.height; + const auto newRatioW = static_cast(desiredSize.width) / size.width; + auto resizeRatio = std::min(newRatioH, newRatioW); + + const auto newHeight = static_cast(size.height * resizeRatio); + const auto newWidth = static_cast(size.width * resizeRatio); + + const int32_t deltaH = desiredSize.height - newHeight; + const int32_t deltaW = desiredSize.width - newWidth; + + const int32_t top = deltaH / 2; + const int32_t left = deltaW / 2; + + const auto heightRatio = static_cast(size.height) / desiredSize.height; + const auto widthRatio = 
static_cast(size.width) / desiredSize.width; + + resizeRatio = std::max(heightRatio, widthRatio); + return {resizeRatio, top, left}; +} + +void computeRatioAndResize(cv::Mat &img, cv::Size size, int32_t modelHeight) { + auto ratio = + static_cast(size.width) / static_cast(size.height); + cv::Size resizedSize; + if (ratio < 1.0) { + resizedSize = + cv::Size(modelHeight, static_cast(modelHeight / ratio)); + } else { + resizedSize = + cv::Size(static_cast(modelHeight * ratio), modelHeight); + } + cv::resize(img, img, resizedSize, 0.0, 0.0, cv::INTER_LANCZOS4); +} + +cv::Mat cropImage(DetectorBBox box, cv::Mat &image, int32_t modelHeight) { + // Convert custom points to cv::Point2f + std::array points; +#pragma unroll + for (std::size_t i = 0; i < points.size(); ++i) { + points[i] = cv::Point2f(box.bbox[i].x, box.bbox[i].y); + } + + cv::RotatedRect rotatedRect = cv::minAreaRect(points); + cv::Point2f rectPoints[4]; + rotatedRect.points(rectPoints); + + // Rotate the image + cv::Point2f imageCenter(image.cols / 2.0f, image.rows / 2.0f); + cv::Mat rotationMatrix = cv::getRotationMatrix2D(imageCenter, box.angle, 1.0); + cv::Mat rotatedImage; + cv::warpAffine(image, rotatedImage, rotationMatrix, image.size(), + cv::INTER_LINEAR); + + cv::Mat rectMat(4, 2, CV_32FC2); +#pragma unroll + for (int32_t i = 0; i < rectMat.rows; ++i) { + rectMat.at(i, 0) = cv::Vec2f(rectPoints[i].x, rectPoints[i].y); + } + cv::transform(rectMat, rectMat, rotationMatrix); + + std::vector transformedPoints(4); +#pragma unroll + for (std::size_t i = 0; i < transformedPoints.size(); ++i) { + cv::Vec2f point = rectMat.at(i, 0); + transformedPoints[i] = cv::Point2f(point[0], point[1]); + } + + cv::Rect boundingBox = cv::boundingRect(transformedPoints); + + cv::Rect validRegion(0, 0, rotatedImage.cols, rotatedImage.rows); + + boundingBox = boundingBox & validRegion; // OpenCV's built-in intersection + + if (boundingBox.empty()) { + return {}; + } + + cv::Mat croppedImage = 
rotatedImage(boundingBox).clone(); + + computeRatioAndResize(croppedImage, + cv::Size(boundingBox.width, boundingBox.height), + modelHeight); + + return croppedImage; +} + +void adjustContrastGrey(cv::Mat &img, double target) { + constexpr double minValue = 0.0; + constexpr double maxValue = 255.0; + + // calculate the brightest and the darkest point from the img + double highDouble; + double lowDouble; + cv::minMaxLoc(img, &lowDouble, &highDouble); + const auto low = static_cast(lowDouble); + const auto high = static_cast(highDouble); + + double contrast = (highDouble - lowDouble) / maxValue; + if (contrast < target) { + constexpr double maxStretchIntensity = 200.0; + constexpr int32_t minRangeClamp = 10; + // Defines how much the contrast will actually stretch. + // Formula obtained empirically. + double ratio = maxStretchIntensity / std::max(minRangeClamp, high - low); + cv::Mat tempImg; + img.convertTo(tempImg, CV_32F); + constexpr int32_t histogramShift = 25; + + tempImg -= (low - histogramShift); + tempImg *= ratio; + + cv::threshold(tempImg, tempImg, maxValue, maxValue, cv::THRESH_TRUNC); + cv::threshold(tempImg, tempImg, minValue, maxValue, cv::THRESH_TOZERO); + + tempImg.convertTo(img, CV_8U); + } +} + +int32_t getDesiredWidth(const cv::Mat &img, bool isVertical) { + + if (img.cols >= largeRecognizerWidth) { + return largeRecognizerWidth; + } + if (img.cols >= mediumRecognizerWidth) { + return mediumRecognizerWidth; + } + return isVertical ? 
smallVerticalRecognizerWidth : smallRecognizerWidth; +} + +cv::Mat normalizeForRecognizer(const cv::Mat &image, int32_t modelHeight, + double adjustContrast, bool isVertical) { + auto img = image.clone(); + if (adjustContrast > 0.0) { + adjustContrastGrey(img, adjustContrast); + } + + int32_t desiredWidth = getDesiredWidth(image, isVertical); + + img = imageprocessing::resizePadded(img, cv::Size(desiredWidth, modelHeight)); + img.convertTo(img, CV_32F, 1.0f / 255.0f); + img -= 0.5f; + img *= 2.0f; + return img; +} +} // namespace rnexecutorch::ocr diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandlerUtils.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandlerUtils.h new file mode 100644 index 0000000000..2d15e9933d --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognitionHandlerUtils.h @@ -0,0 +1,72 @@ +#pragma once + +#include +#include + +namespace rnexecutorch::ocr { +/** + * @brief Calculates the resize ratio and padding offsets needed to fit an image + * into a target size while maintaining aspect ratio. + * @param size Original dimensions of the image. + * @param desiredSize Target size. + * @return Struct containing the scaling factor and top/left padding amounts for + * centering the image. + */ +PaddingInfo calculateResizeRatioAndPaddings(cv::Size size, + cv::Size desiredSize); +/** + * @brief Resizes an image proportionally to match a target height while + * maintaining aspect ratio. + * @param img Input/output image to resize. + * @param size Original dimensions of the image. + * @param modelHeight Target height for the output image. + */ +void computeRatioAndResize(cv::Mat &img, cv::Size size, int32_t modelHeight); +/** + * @brief Crops and aligns a rotated bounding box region from an image, then + * resizes it to target height. + * + * Handles rotated boxes by: + * 1. Calculating minimum area rectangle around detected points + * 2. 
Rotating the entire image to align the box horizontally + * 3. Transforming the box coordinates to match the rotated image + * 4. Cropping the aligned region + * + * Resizing: + * - Maintains original aspect ratio while scaling to specified modelHeight + * - Uses high-quality interpolation for both rotation and resizing + * + * @param box Detected bounding box with rotation angle and corner points + * @param image Source image to crop from + * @param modelHeight Target height for output (width scales proportionally) + * @return Cropped, aligned and resized image region (empty if invalid box) + */ +cv::Mat cropImage(DetectorBBox box, cv::Mat &image, int32_t modelHeight); +void adjustContrastGrey(cv::Mat &img, double target); +/** + * @brief Prepares an image for recognition models by standardizing size, + * contrast, and pixel values. + * + * Performs the following processing pipeline: + * 1. Adjusts contrast (if coefficient > 0) + * 2. Resizes to target height while: + * - Preserving aspect ratio (using padding if needed) + * - Selecting width to match one of the Recognizer accepted + * widths; (Large,Medium or Small RecognizerWidth) + * 3. 
Normalizes pixel values to [-1, 1] range (from [0,255] input) + * + * @param image Input image to process (any size, will be cloned) + * @param modelHeight Target output height in pixels + * @param adjustContrast Contrast adjustment coefficient (0.0 = no adjustment) + * @param isVertical Whether the image is in portrait orientation (affects width + * selection) + * + * @return Processed image with: + * - Standardized dimensions (selected width × modelHeight) + * - Adjusted contrast (if requested) + * - Normalized float32 values in [-1, 1] range + */ +cv::Mat normalizeForRecognizer(const cv::Mat &image, int32_t modelHeight, + double adjustContrast = 0.0, + bool isVertical = false); +} // namespace rnexecutorch::ocr diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/Recognizer.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Recognizer.cpp new file mode 100644 index 0000000000..699332aeb2 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Recognizer.cpp @@ -0,0 +1,80 @@ +#include "Recognizer.h" +#include +#include +#include +#include +#include +#include + +namespace rnexecutorch { +Recognizer::Recognizer(const std::string &modelSource, + std::shared_ptr callInvoker) + : BaseModel(modelSource, callInvoker) { + auto inputShapes = getAllInputShapes(); + if (inputShapes.empty()) { + throw std::runtime_error("Recognizer model has no input tensors."); + } + std::vector modelInputShape = inputShapes[0]; + if (modelInputShape.size() < 2) { + throw std::runtime_error("Unexpected Recognizer model input shape."); + } + modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], + modelInputShape[modelInputShape.size() - 2]); +} + +std::pair, float> +Recognizer::generate(const cv::Mat &grayImage) { + /* + In our pipeline we use three types of Recognizer, each designated to + handle different image sizes: + - Small Recognizer - 128 x 64 + - Medium Recognizer - 256 x 64 + - Large Recognizer 
- 512 x 64 + The `generate` function as an argument accepts an image in grayscale + already resized to the expected size. + */ + std::vector tensorDims = getAllInputShapes()[0]; + TensorPtr inputTensor = + imageprocessing::getTensorFromMatrixGray(tensorDims, grayImage); + auto forwardResult = BaseModel::forward(inputTensor); + if (!forwardResult.ok()) { + throw std::runtime_error( + "Failed to forward in Recognizer, error: " + + std::to_string(static_cast(forwardResult.error()))); + } + + return postprocess(forwardResult->at(0).toTensor()); +} + +std::pair, float> +Recognizer::postprocess(const Tensor &tensor) const { + /* + Raw model returns a tensor with dimensions [ 1 x seqLen x alphabetSize ] + where: + + - seqLen is the length of predicted sequence. It is constant for the model. + For our models it is: + - 31 for Small Recognizer + - 63 for Medium Recognizer + - 127 for Large Recognizer + Remember that usually many tokens of predicted sequences are blank, meaning + the predicted text is not of const size. + + - alphabetSize is the length of considered alphabet. It is constant for the + model. Usually depends on language, e.g. for our models for english it is 97, + for polish it is 357 etc. + + Each value of returned tensor corresponds to character logits. 
+ */ + const int32_t alphabetSize = tensor.size(2); + const int32_t numRows = tensor.numel() / alphabetSize; + + cv::Mat resultMat(numRows, alphabetSize, CV_32F, + tensor.mutable_data_ptr()); + + auto probabilities = ocr::softmax(resultMat); + auto [maxVal, maxIndices] = ocr::findMaxValuesIndices(probabilities); + float confidence = ocr::confidenceScore(maxVal, maxIndices); + return {maxIndices, confidence}; +} +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/Recognizer.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Recognizer.h new file mode 100644 index 0000000000..50d61150c0 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Recognizer.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace rnexecutorch { +/* + Recognizer is a model responsible for interpreting detected text regions + into characters/words. + + The model used as Recognizer is based on CRNN paper. + https://arxiv.org/pdf/1507.05717 + + It returns the list of predicted indices and a confidence value. 
+*/ + +using executorch::aten::Tensor; +using executorch::extension::TensorPtr; + +class Recognizer final : public BaseModel { +public: + explicit Recognizer(const std::string &modelSource, + std::shared_ptr callInvoker); + std::pair, float> generate(const cv::Mat &grayImage); + +private: + std::pair, float> + postprocess(const Tensor &tensor) const; + + cv::Size modelImageSize; +}; +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognizerUtils.cpp b/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognizerUtils.cpp new file mode 100644 index 0000000000..8c1ba1057f --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognizerUtils.cpp @@ -0,0 +1,201 @@ +#include "RecognizerUtils.h" + +namespace rnexecutorch::ocr { +cv::Mat softmax(const cv::Mat &inputs) { + cv::Mat maxVal; + cv::reduce(inputs, maxVal, 1, cv::REDUCE_MAX, CV_32F); + cv::Mat expInputs; + cv::exp(inputs - cv::repeat(maxVal, 1, inputs.cols), expInputs); + cv::Mat sumExp; + cv::reduce(expInputs, sumExp, 1, cv::REDUCE_SUM, CV_32F); + cv::Mat softmaxOutput = expInputs / cv::repeat(sumExp, 1, inputs.cols); + + return softmaxOutput; +} + +std::vector sumProbabilityRows(const cv::Mat &matrix) { + std::vector sums; + sums.reserve(matrix.rows); + for (int32_t i = 0; i < matrix.rows; ++i) { + sums.push_back(cv::sum(matrix.row(i))[0]); + } + return sums; +} + +void divideMatrixByRows(cv::Mat &matrix, const std::vector &rowSums) { + for (int32_t i = 0; i < matrix.rows; ++i) { + matrix.row(i) /= rowSums[i]; + } +} + +ValuesAndIndices findMaxValuesIndices(const cv::Mat &mat) { + CV_Assert(mat.type() == CV_32F); + ValuesAndIndices result{}; + result.values.reserve(mat.rows); + result.indices.reserve(mat.rows); + + for (int32_t i = 0; i < mat.rows; ++i) { + double maxVal; + cv::Point maxLoc; + cv::minMaxLoc(mat.row(i), nullptr, &maxVal, nullptr, &maxLoc); + result.values.push_back(static_cast(maxVal)); + 
result.indices.push_back(maxLoc.x); + } + + return result; +} + +float confidenceScore(const std::vector &values, + const std::vector &indices) { + float product = 1.0f; + int32_t count = 0; + + for (size_t i = 0; i < indices.size(); ++i) { + if (indices[i] != 0) { + product *= values[i]; + count++; + } + } + + if (count == 0) { + return 0.0f; + } + + const float n = static_cast(count); + const float exponent = 2.0f / std::sqrt(n); + return std::pow(product, exponent); +} + +cv::Rect extractBoundingBox(std::array &points) { + cv::Mat pointsMat(4, 1, CV_32FC2, points.data()); + return cv::boundingRect(pointsMat); +} + +cv::Mat characterBitMask(const cv::Mat &img) { + // 1. Determine if character is darker/lighter than background. + cv::Mat histogram; + int32_t histSize = 256; + float range[] = {0.0f, 256.0f}; + const float *histRange = {range}; + bool uniform = true; + bool accumulate = false; + + cv::calcHist(&img, 1, 0, cv::Mat(), histogram, 1, &histSize, &histRange, + uniform, accumulate); + + // Compare sum of darker (left half) vs brighter (right half) pixels. + const int32_t midPoint = histSize / 2; + double sumLeft = 0.0; + double sumRight = 0.0; + for (int32_t i = 0; i < midPoint; i++) { + sumLeft += histogram.at(i); + } + for (int32_t i = midPoint; i < histSize; i++) { + sumRight += histogram.at(i); + } + const int32_t thresholdType = + (sumLeft < sumRight) ? cv::THRESH_BINARY_INV : cv::THRESH_BINARY; + + // 2. Binarize using Otsu's method (auto threshold). + cv::Mat thresh; + cv::threshold(img, thresh, 0, 255, thresholdType + cv::THRESH_OTSU); + + // 3. Find the largest connected component near the center. 
+ cv::Mat labels, stats, centroids; + const int32_t numLabels = cv::connectedComponentsWithStats( + thresh, labels, stats, centroids, 8, CV_32S); + + const int32_t height = thresh.rows; + const int32_t width = thresh.cols; + const int32_t minX = ocr::singleCharacterCenterThreshold * width; + const int32_t maxX = (1 - ocr::singleCharacterCenterThreshold) * width; + const int32_t minY = ocr::singleCharacterCenterThreshold * height; + const int32_t maxY = (1 - ocr::singleCharacterCenterThreshold) * height; + + int32_t selectedComponent = -1; + int32_t maxArea = -1; + for (int32_t i = 1; i < numLabels; i++) { // Skip background (label 0) + const int32_t area = stats.at(i, cv::CC_STAT_AREA); + const double cx = centroids.at(i, 0); + const double cy = centroids.at(i, 1); + + if ((minX < cx && cx < maxX && minY < cy && + cy < maxY && // check if centered + area > ocr::singleCharacterMinSize) && // check if large enough + area > maxArea) { + selectedComponent = i; + maxArea = area; + } + } + // 4. Extract the character and invert to white-on-black. 
+ cv::Mat resultImage; + cv::Mat mask; + if (selectedComponent != -1) { + mask = (labels == selectedComponent); + img.copyTo(resultImage, mask); + } else { + resultImage = cv::Mat::zeros(img.size(), img.type()); + } + + cv::bitwise_not(resultImage, resultImage); + + return resultImage; +} + +cv::Mat cropImageWithBoundingBox(const cv::Mat &img, + const std::array &bbox, + const std::array &originalBbox, + const PaddingInfo &paddings, + const PaddingInfo &originalPaddings) { + if (originalBbox.empty()) { + throw std::runtime_error("Original bounding box cannot be empty."); + } + const Point topLeft = originalBbox[0]; + + std::vector points; + points.reserve(bbox.size()); + + for (const auto &point : bbox) { + Point transformedPoint = point; + + transformedPoint.x -= paddings.left; + transformedPoint.y -= paddings.top; + + transformedPoint.x *= paddings.resizeRatio; + transformedPoint.y *= paddings.resizeRatio; + + transformedPoint.x += topLeft.x; + transformedPoint.y += topLeft.y; + + transformedPoint.x -= originalPaddings.left; + transformedPoint.y -= originalPaddings.top; + + transformedPoint.x *= originalPaddings.resizeRatio; + transformedPoint.y *= originalPaddings.resizeRatio; + + points.emplace_back(transformedPoint.x, transformedPoint.y); + } + + cv::Rect rect = cv::boundingRect(points); + rect &= cv::Rect(0, 0, img.cols, img.rows); + if (rect.empty()) { + return {}; + } + auto croppedImage = img(rect).clone(); + return croppedImage; +} + +cv::Mat prepareForRecognition(const cv::Mat &originalImage, + const std::array &bbox, + const std::array &originalBbox, + const PaddingInfo &paddings, + const PaddingInfo &originalPaddings) { + auto croppedChar = cropImageWithBoundingBox(originalImage, bbox, originalBbox, + paddings, originalPaddings); + cv::cvtColor(croppedChar, croppedChar, cv::COLOR_BGR2GRAY); + cv::resize(croppedChar, croppedChar, + cv::Size(ocr::smallVerticalRecognizerWidth, ocr::recognizerHeight), + 0, 0, cv::INTER_AREA); + return croppedChar; +} +} // 
namespace rnexecutorch::ocr diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognizerUtils.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognizerUtils.h new file mode 100644 index 0000000000..5bd5ab1d04 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/RecognizerUtils.h @@ -0,0 +1,69 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace rnexecutorch::ocr { +/** + * @brief Computes per-row softmax function. + * Formula: softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x))) for each + * row. + */ +cv::Mat softmax(const cv::Mat &inputs); + +/** + * @brief For each row of matrix computes {maxValue, index} pair. Returns a list + * of maxValues and a list of corresponding indices. + */ +ValuesAndIndices findMaxValuesIndices(const cv::Mat &mat); +std::vector sumProbabilityRows(const cv::Mat &matrix); +void divideMatrixByRows(cv::Mat &matrix, const std::vector &rowSums); +cv::Rect extractBoundingBox(std::array &points); + +/** + * @brief Computes confidence score for given values and indices vectors. + * Omits blank tokens. + * Formula: pow(\prod_{i=1}^{n}(p_i), 2/sqrt(n)), where n is a number of + * non-blank tokens, and p_i is the probability of i-th non-blank token. + * @details Formula derived from line 14 of + * https://github.com/JaidedAI/EasyOCR/blob/c4f3cd7225efd4f85451bd8b4a7646ae9a092420/easyocr/recognition.py#L14 + * @details 'Some say that it's a code, sent to us from god' + */ +float confidenceScore(const std::vector &values, + const std::vector &indices); + +cv::Mat characterBitMask(const cv::Mat &img); + +/** + * @brief Perform cropping of an image to a single character detector box. + * This function utilizes info about external bounding box and padding combined + * with internal bounding box and padding. + * It does so to preserve the best possible image quality. 
+ */ +cv::Mat cropImageWithBoundingBox(const cv::Mat &img, + const std::array &bbox, + const std::array &originalBbox, + const PaddingInfo &paddings, + const PaddingInfo &originalPaddings); + +/** + * @brief Perform cropping, resizing and convert to grayscale to prepare image + * for Recognizer. + * + * Prepare for Recognition by following steps: + * 1. Crop image to the character bounding box, + * 2. Convert Image to gray. + * 3. Resize it to [smallVerticalRecognizerWidth x recognizerHeight] (64 x 64). + * + * @details it utilizes cropImageWithBoundingBox to perform specific cropping. + */ + +cv::Mat prepareForRecognition(const cv::Mat &originalImage, + const std::array &bbox, + const std::array &originalBbox, + const PaddingInfo &paddings, + const PaddingInfo &originalPaddings); +} // namespace rnexecutorch::ocr diff --git a/packages/react-native-executorch/common/rnexecutorch/models/ocr/Types.h b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Types.h new file mode 100644 index 0000000000..a5e4a2b44d --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/ocr/Types.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include + +namespace rnexecutorch { +namespace ocr { + +struct Point { + float x; + float y; +}; + +struct ValuesAndIndices { + std::vector values; + std::vector indices; +}; + +struct DetectorBBox { + std::array bbox; + float angle; +}; + +struct PaddingInfo { + float resizeRatio; + int32_t top; + int32_t left; +}; +} // namespace ocr + +struct OCRDetection { + std::array bbox; + std::string text; + float score; +}; +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalDetector.cpp b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalDetector.cpp new file mode 100644 index 0000000000..1a269b8af3 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalDetector.cpp 
@@ -0,0 +1,92 @@ +#include "VerticalDetector.h" + +#include +#include +#include + +#include + +namespace rnexecutorch { +VerticalDetector::VerticalDetector( + const std::string &modelSource, bool detectSingleCharacters, + std::shared_ptr callInvoker) + : BaseModel(modelSource, callInvoker) { + this->detectSingleCharacters = detectSingleCharacters; + auto inputShapes = getAllInputShapes(); + if (inputShapes.empty()) { + throw std::runtime_error( + "Detector model seems to not take any input tensors."); + } + std::vector modelInputShape = inputShapes[0]; + if (modelInputShape.size() < 2) { + throw std::runtime_error("Unexpected detector model input size, expected " + "at least 2 dimensions but got: " + + std::to_string(modelInputShape.size()) + "."); + } + modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1], + modelInputShape[modelInputShape.size() - 2]); +} + +cv::Size VerticalDetector::getModelImageSize() const noexcept { + return modelImageSize; +} + +std::vector +VerticalDetector::generate(const cv::Mat &inputImage) { + auto inputShapes = getAllInputShapes(); + cv::Mat resizedInputImage = + imageprocessing::resizePadded(inputImage, getModelImageSize()); + TensorPtr inputTensor = imageprocessing::getTensorFromMatrix( + inputShapes[0], resizedInputImage, ocr::mean, ocr::variance); + auto forwardResult = BaseModel::forward(inputTensor); + if (!forwardResult.ok()) { + throw std::runtime_error( + "Failed to forward, error: " + + std::to_string(static_cast(forwardResult.error()))); + } + return postprocess(forwardResult->at(0).toTensor()); +} + +std::vector +VerticalDetector::postprocess(const Tensor &tensor) const { + /* + The output of the model consists of two matrices (heat maps): + 1. ScoreText(Score map) - The probability of a region containing a character. + 2. ScoreAffinity(Affinity map) - affinity between characters, used to + group each character into a single instance (sequence). Both matrices are + H/2xW/2. 
+ + The result of this step is a list of bounding boxes that contain text. + */ + std::span tensorData(tensor.const_data_ptr(),
 tensor.numel()); + /* + The output of the model is a matrix half the size of the input image + containing two channels representing the heatmaps. + */ + auto [scoreTextMat, scoreAffinityMat] = ocr::interleavedArrayToMats( + tensorData, + cv::Size(modelImageSize.width / 2, modelImageSize.height / 2)); + float txtThreshold = this->detectSingleCharacters + ? ocr::textThreshold + : ocr::textThresholdVertical; + std::vector bBoxesList = + ocr::getDetBoxesFromTextMapVertical(scoreTextMat, scoreAffinityMat, + txtThreshold, ocr::linkThreshold, + this->detectSingleCharacters); + const float restoreRatio = + ocr::calculateRestoreRatio(scoreTextMat.rows, ocr::recognizerImageSize); + ocr::restoreBboxRatio(bBoxesList, restoreRatio); + + // if this is Narrow Detector, do not group boxes. + if (!this->detectSingleCharacters) { + bBoxesList = ocr::groupTextBoxes( + bBoxesList, ocr::centerThreshold, ocr::distanceThreshold, + ocr::heightThreshold, ocr::minSideThreshold, ocr::maxSideThreshold, + ocr::maxWidth); + } + + return bBoxesList; +} + +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalDetector.h b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalDetector.h new file mode 100644 index 0000000000..aae898f16c --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalDetector.h @@ -0,0 +1,49 @@ +#pragma once + +#include +#include + +#include +#include + +namespace rnexecutorch { + +/* + Vertical Detector is a slightly modified Detector tuned for detecting Vertical + text. For more details about standard detector, refer to the file + ocr/Detector.cpp. + + In Vertical OCR pipeline we make use of Detector two times: + + 1. 
Large Detector -- The differences between Detector used in standard OCR and + Large Detector used in Vertical OCR is: a) To obtain detected boxes from heat + maps it utilizes `getDetBoxesFromTextMapVertical()` function rather than + `getDetBoxesFromTextMap()`. Other than that, refer to the standard OCR + Detector. + + 2. Narrow Detector -- it is designed to detect single characters' bounding + boxes. `getDetBoxesFromTextMapVertical()` function acts differently for Narrow + Detector and a different textThreshold value is passed. Additionally, the + grouping of detected boxes is completely omitted. + + Vertical Detector pipeline differentiates the Large Detector and Narrow + Detector based on `detectSingleCharacters` flag passed to the constructor. +*/ + +using executorch::aten::Tensor; +using executorch::extension::TensorPtr; + +class VerticalDetector final : public BaseModel { +public: + explicit VerticalDetector(const std::string &modelSource, + bool detectSingleCharacters, + std::shared_ptr callInvoker); + std::vector generate(const cv::Mat &inputImage); + cv::Size getModelImageSize() const noexcept; + +private: + bool detectSingleCharacters; + std::vector postprocess(const Tensor &tensor) const; + cv::Size modelImageSize; +}; +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp new file mode 100644 index 0000000000..450b135685 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.cpp @@ -0,0 +1,180 @@ +#include "VerticalOCR.h" +#include +#include +#include +#include +#include + +namespace rnexecutorch { +VerticalOCR::VerticalOCR(const std::string &detectorLargeSource, + const std::string &detectorNarrowSource, + const std::string &recognizerSource, + std::string symbols, bool independentChars, + std::shared_ptr invoker) + : 
detectorLarge(detectorLargeSource, false, invoker), + detectorNarrow(detectorNarrowSource, true, invoker), + recognizer(recognizerSource, invoker), converter(symbols), + independentCharacters(independentChars), callInvoker(invoker) {} + +std::vector VerticalOCR::generate(std::string input) { + cv::Mat image = imageprocessing::readImage(input); + if (image.empty()) { + throw std::runtime_error("Failed to load image from path: " + input); + } + // 1. Large Detector + std::vector largeBoxes = detectorLarge.generate(image); + + cv::Size largeDetectorSize = detectorLarge.getModelImageSize(); + cv::Mat resizedImage = + imageprocessing::resizePadded(image, largeDetectorSize); + ocr::PaddingInfo imagePaddings = + ocr::calculateResizeRatioAndPaddings(image.size(), largeDetectorSize); + + std::vector predictions; + predictions.reserve(largeBoxes.size()); + + for (auto &box : largeBoxes) { + predictions.push_back( + _processSingleTextBox(box, image, resizedImage, imagePaddings)); + } + + return predictions; +} + +std::size_t VerticalOCR::getMemoryLowerBound() const noexcept { + return detectorLarge.getMemoryLowerBound() + + detectorNarrow.getMemoryLowerBound() + + recognizer.getMemoryLowerBound(); +} + +// Strategy 1: Recognize each character individually +std::pair VerticalOCR::_handleIndependentCharacters( + const ocr::DetectorBBox &box, const cv::Mat &originalImage, + const std::vector &characterBoxes, + const ocr::PaddingInfo &paddingsBox, + const ocr::PaddingInfo &imagePaddings) { + std::string text; + float confidenceScore = 0.0f; + float totalScore = 0.0f; + for (const auto &characterBox : characterBoxes) { + + /* + Prepare for Recognition by following steps: + 1. Crop image to the character bounding box, + 2. Convert Image to gray. + 3. 
Resize it to [VerticalSmallRecognizerWidth x RecognizerHeight] (64 x + 64), + */ + auto croppedChar = ocr::prepareForRecognition( + originalImage, characterBox.bbox, box.bbox, paddingsBox, imagePaddings); + + /* + To make Recognition simpler, we convert cropped character image + to a bit mask with white character and black background. + */ + croppedChar = ocr::characterBitMask(croppedChar); + croppedChar = ocr::normalizeForRecognizer(croppedChar, + ocr::recognizerHeight, 0.0, true); + + const auto &[predIndex, score] = recognizer.generate(croppedChar); + if (!predIndex.empty()) { + text += converter.decodeGreedy(predIndex, predIndex.size())[0]; + } + totalScore += score; + } + confidenceScore = totalScore / characterBoxes.size(); + return {text, confidenceScore}; +} + +// Strategy 2: Concatenate characters and recognize as a single line +std::pair VerticalOCR::_handleJointCharacters( + const ocr::DetectorBBox &box, const cv::Mat &originalImage, + const std::vector &characterBoxes, + const ocr::PaddingInfo &paddingsBox, + const ocr::PaddingInfo &imagePaddings) { + std::string text; + std::vector croppedCharacters; + croppedCharacters.reserve(characterBoxes.size()); + for (const auto &characterBox : characterBoxes) { + /* + Prepare for Recognition by following steps: + 1. Crop image to the character bounding box, + 2. Convert Image to gray. + 3. Resize it to [smallVerticalRecognizerWidth x recognizerHeight] (64 x + 64). The same height is required for horizontal concatenation of single + characters into one image. 
+ */ + auto croppedChar = ocr::prepareForRecognition( + originalImage, characterBox.bbox, box.bbox, paddingsBox, imagePaddings); + croppedCharacters.push_back(croppedChar); + } + + cv::Mat mergedCharacters; + cv::hconcat(croppedCharacters, mergedCharacters); + mergedCharacters = imageprocessing::resizePadded( + mergedCharacters, + cv::Size(ocr::largeRecognizerWidth, ocr::recognizerHeight)); + mergedCharacters = ocr::normalizeForRecognizer( + mergedCharacters, ocr::recognizerHeight, 0.0, false); + + const auto &[predIndex, confidenceScore] = + recognizer.generate(mergedCharacters); + if (!predIndex.empty()) { + text = converter.decodeGreedy(predIndex, predIndex.size())[0]; + } + return {text, confidenceScore}; +} + +OCRDetection VerticalOCR::_processSingleTextBox( + ocr::DetectorBBox &box, const cv::Mat &originalImage, + const cv::Mat &resizedLargeImage, const ocr::PaddingInfo &imagePaddings) { + cv::Rect boundingBox = ocr::extractBoundingBox(box.bbox); + + // Crop the image for detection of single characters. + cv::Rect safeRect = + boundingBox & cv::Rect(0, 0, resizedLargeImage.cols, + resizedLargeImage.rows); // ensure valid box + cv::Mat croppedLargeBox = resizedLargeImage(safeRect); + + // 2. Narrow Detector - detects single characters + std::vector characterBoxes = + detectorNarrow.generate(croppedLargeBox); + + std::string text; + float confidenceScore = 0.0; + if (!characterBoxes.empty()) { + // Prepare information useful for proper boxes shifting and image cropping. + const int32_t boxWidth = + static_cast(box.bbox[2].x - box.bbox[0].x); + const int32_t boxHeight = + static_cast(box.bbox[2].y - box.bbox[0].y); + cv::Size narrowRecognizerSize = detectorNarrow.getModelImageSize(); + ocr::PaddingInfo paddingsBox = ocr::calculateResizeRatioAndPaddings( + cv::Size(boxWidth, boxHeight), narrowRecognizerSize); + + // 3. Recognition - decide between Strategy 1 and Strategy 2. + std::tie(text, confidenceScore) = + independentCharacters + ? 
_handleIndependentCharacters(box, originalImage, characterBoxes, + paddingsBox, imagePaddings) + : _handleJointCharacters(box, originalImage, characterBoxes, + paddingsBox, imagePaddings); + } + // Modify the returned boxes to match the original image size + std::array finalBbox; + for (size_t i = 0; i < box.bbox.size(); ++i) { + finalBbox[i].x = + (box.bbox[i].x - imagePaddings.left) * imagePaddings.resizeRatio; + finalBbox[i].y = + (box.bbox[i].y - imagePaddings.top) * imagePaddings.resizeRatio; + } + + return {finalBbox, text, confidenceScore}; +} + +void VerticalOCR::unload() noexcept { + detectorLarge.unload(); + detectorNarrow.unload(); + recognizer.unload(); +} +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h new file mode 100644 index 0000000000..feccf3c723 --- /dev/null +++ b/packages/react-native-executorch/common/rnexecutorch/models/vertical_ocr/VerticalOCR.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace rnexecutorch { + +/* + Vertical OCR is OCR designed to handle vertical texts. + Vertical OCR pipeline consists of: + 1. Large Detector -- detects regions where text is located. + Almost identical to the Detector in standard OCR. + The result of this phase is a list of bounding boxes. + Each detected box is then processed individually through the following steps: + 2. Narrow Detector -- designed for detecting where single characters + are located. + There are two different strategies used for vertical recognition: + Strategy 1 "Independent Characters": + Treating each character region found by Narrow Detector + as completely independent. + 3. Each character is forwarded to Small Recognizer (64 x 64). 
+ Strategy 2 "Joint Characters": + The bounding boxes found by Narrow Detector are + horizontally merged to create one wide image. + 3. One wide image is forwarded to Large Recognizer (512 x 64). + Vertical OCR differentiates between those two strategies based on + `independentChars` flag passed to the constructor. +*/ + +using executorch::aten::Tensor; +using executorch::extension::TensorPtr; + +class VerticalOCR final { +public: + explicit VerticalOCR(const std::string &detectorLargeSource, + const std::string &detectorNarrowSource, + const std::string &recognizerSource, std::string symbols, + bool indpendentCharacters, + std::shared_ptr callInvoker); + std::vector generate(std::string input); + std::size_t getMemoryLowerBound() const noexcept; + void unload() noexcept; + +private: + std::pair _handleIndependentCharacters( + const ocr::DetectorBBox &box, const cv::Mat &originalImage, + const std::vector &characterBoxes, + const ocr::PaddingInfo &paddingsBox, + const ocr::PaddingInfo &imagePaddings); + std::pair + _handleJointCharacters(const ocr::DetectorBBox &box, + const cv::Mat &originalImage, + const std::vector &characterBoxes, + const ocr::PaddingInfo &paddingsBox, + const ocr::PaddingInfo &imagePaddings); + OCRDetection _processSingleTextBox(ocr::DetectorBBox &box, + const cv::Mat &originalImage, + const cv::Mat &resizedLargeImage, + const ocr::PaddingInfo &imagePaddings); + VerticalDetector detectorLarge; + VerticalDetector detectorNarrow; + Recognizer recognizer; + ocr::CTCLabelConverter converter; + bool independentCharacters; + std::shared_ptr callInvoker; +}; + +} // namespace rnexecutorch diff --git a/packages/react-native-executorch/ios/RnExecutorch/OCR.h b/packages/react-native-executorch/ios/RnExecutorch/OCR.h deleted file mode 100644 index 4994108bce..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/OCR.h +++ /dev/null @@ -1,5 +0,0 @@ -#import - -@interface OCR : NSObject - -@end diff --git 
a/packages/react-native-executorch/ios/RnExecutorch/OCR.mm b/packages/react-native-executorch/ios/RnExecutorch/OCR.mm deleted file mode 100644 index 69fe35f1ef..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/OCR.mm +++ /dev/null @@ -1,96 +0,0 @@ -#import "OCR.h" -#import "models/ocr/Detector.h" -#import "models/ocr/RecognitionHandler.h" -#import "models/ocr/utils/Constants.h" -#import "utils/ImageProcessor.h" - -@implementation OCR { - Detector *detector; - RecognitionHandler *recognitionHandler; -} - -RCT_EXPORT_MODULE() - -- (void)releaseResources { - detector = nil; - recognitionHandler = nil; -} - -- (void)loadModule:(NSString *)detectorSource - recognizerSourceLarge:(NSString *)recognizerSourceLarge - recognizerSourceMedium:(NSString *)recognizerSourceMedium - recognizerSourceSmall:(NSString *)recognizerSourceSmall - symbols:(NSString *)symbols - resolve:(RCTPromiseResolveBlock)resolve - reject:(RCTPromiseRejectBlock)reject { - detector = [[Detector alloc] init]; - NSNumber *errorCode = [detector loadModel:detectorSource]; - if ([errorCode intValue] != 0) { - [self releaseResources]; - NSError *error = [NSError - errorWithDomain:@"OCRErrorDomain" - code:[errorCode intValue] - userInfo:@{ - NSLocalizedDescriptionKey : [NSString - stringWithFormat:@"%ld", (long)[errorCode longValue]] - }]; - reject(@"init_module_error", @"Failed to initialize detector module", - error); - return; - } - - recognitionHandler = [[RecognitionHandler alloc] initWithSymbols:symbols]; - errorCode = [recognitionHandler loadRecognizers:recognizerSourceLarge - mediumRecognizerPath:recognizerSourceMedium - smallRecognizerPath:recognizerSourceSmall]; - if ([errorCode intValue] != 0) { - [self releaseResources]; - NSError *error = [NSError - errorWithDomain:@"OCRErrorDomain" - code:[errorCode intValue] - userInfo:@{ - NSLocalizedDescriptionKey : [NSString - stringWithFormat:@"%ld", (long)[errorCode longValue]] - }]; - reject(@"init_recognizer_error", - @"Failed to 
initialize one or more recognizer models", error); - return; - } - - resolve(@0); -} - -- (void)forward:(NSString *)input - resolve:(RCTPromiseResolveBlock)resolve - reject:(RCTPromiseRejectBlock)reject { - /* - The OCR consists of two phases: - 1. Detection - detecting text regions in the image, the result of this phase - is a list of bounding boxes. - 2. Recognition - recognizing the text in the bounding boxes, the result is a - list of strings and corresponding confidence scores. - - Recognition uses three models, each model is resposible for recognizing text - of different sizes (e.g. large - 512x64, medium - 256x64, small - 128x64). - */ - @try { - cv::Mat image = [ImageProcessor readImage:input]; - NSArray *result = [detector runModel:image]; - cv::cvtColor(image, image, cv::COLOR_BGR2GRAY); - result = [self->recognitionHandler recognize:result - imgGray:image - desiredWidth:recognizerImageSize - desiredHeight:recognizerImageSize]; - resolve(result); - } @catch (NSException *exception) { - reject(@"forward_error", - [NSString stringWithFormat:@"%@", exception.reason], nil); - } -} - -- (std::shared_ptr)getTurboModule: - (const facebook::react::ObjCTurboModule::InitParams &)params { - return std::make_shared(params); -} - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/VerticalOCR.h b/packages/react-native-executorch/ios/RnExecutorch/VerticalOCR.h deleted file mode 100644 index 5692d37897..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/VerticalOCR.h +++ /dev/null @@ -1,5 +0,0 @@ -#import - -@interface VerticalOCR : NSObject - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/VerticalOCR.mm b/packages/react-native-executorch/ios/RnExecutorch/VerticalOCR.mm deleted file mode 100644 index 2683c5164f..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/VerticalOCR.mm +++ /dev/null @@ -1,183 +0,0 @@ -#import "VerticalOCR.h" -#import "models/ocr/Recognizer.h" -#import 
"models/ocr/VerticalDetector.h" -#import "models/ocr/utils/CTCLabelConverter.h" -#import "models/ocr/utils/Constants.h" -#import "models/ocr/utils/OCRUtils.h" -#import "models/ocr/utils/RecognizerUtils.h" -#import "utils/ImageProcessor.h" - -@implementation VerticalOCR { - VerticalDetector *detectorLarge; - VerticalDetector *detectorNarrow; - Recognizer *recognizer; - CTCLabelConverter *converter; - BOOL independentCharacters; -} - -RCT_EXPORT_MODULE() - -- (void)releaseResources { - detectorLarge = nil; - detectorNarrow = nil; - recognizer = nil; - converter = nil; -} - -- (void)loadModule:(NSString *)detectorLargeSource - detectorNarrowSource:(NSString *)detectorNarrowSource - recognizerSource:(NSString *)recognizerSource - symbols:(NSString *)symbols - independentCharacters:(BOOL)independentCharacters - resolve:(RCTPromiseResolveBlock)resolve - reject:(RCTPromiseRejectBlock)reject { - converter = [[CTCLabelConverter alloc] initWithCharacters:symbols - separatorList:@{}]; - self->independentCharacters = independentCharacters; - - detectorLarge = [[VerticalDetector alloc] initWithDetectSingleCharacters:NO]; - NSNumber *errorCode = [detectorLarge loadModel:detectorLargeSource]; - if ([errorCode intValue] != 0) { - [self releaseResources]; - reject(@"init_module_error", @"Failed to initialize detector module", nil); - return; - } - - detectorNarrow = - [[VerticalDetector alloc] initWithDetectSingleCharacters:YES]; - errorCode = [detectorNarrow loadModel:detectorNarrowSource]; - if ([errorCode intValue] != 0) { - [self releaseResources]; - reject(@"init_module_error", @"Failed to initialize detector module", nil); - return; - } - - recognizer = [[Recognizer alloc] init]; - errorCode = [recognizer loadModel:recognizerSource]; - if ([errorCode intValue] != 0) { - [self releaseResources]; - reject(@"init_module_error", @"Failed to initialize recognizer module", - nil); - return; - } - - resolve(@0); -} - -- (void)forward:(NSString *)input - 
resolve:(RCTPromiseResolveBlock)resolve - reject:(RCTPromiseRejectBlock)reject { - @try { - cv::Mat image = [ImageProcessor readImage:input]; - NSArray *result = [detectorLarge runModel:image]; - cv::Size largeDetectorSize = [detectorLarge getModelImageSize]; - cv::Mat resizedImage = - [OCRUtils resizeWithPadding:image - desiredWidth:largeDetectorSize.width - desiredHeight:largeDetectorSize.height]; - NSMutableArray *predictions = [NSMutableArray array]; - - for (NSDictionary *box in result) { - NSArray *cords = box[@"bbox"]; - const int boxWidth = [[cords objectAtIndex:2] CGPointValue].x - - [[cords objectAtIndex:0] CGPointValue].x; - const int boxHeight = [[cords objectAtIndex:2] CGPointValue].y - - [[cords objectAtIndex:0] CGPointValue].y; - - cv::Rect boundingBox = [OCRUtils extractBoundingBox:cords]; - cv::Mat croppedImage = resizedImage(boundingBox); - NSDictionary *paddings = [RecognizerUtils - calculateResizeRatioAndPaddings:image.cols - height:image.rows - desiredWidth:largeDetectorSize.width - desiredHeight:largeDetectorSize.height]; - - NSString *text = @""; - NSNumber *confidenceScore = @0.0; - NSArray *boxResult = [detectorNarrow runModel:croppedImage]; - std::vector croppedCharacters; - cv::Size narrowRecognizerSize = [detectorNarrow getModelImageSize]; - for (NSDictionary *characterBox in boxResult) { - NSArray *boxCords = characterBox[@"bbox"]; - NSDictionary *paddingsBox = [RecognizerUtils - calculateResizeRatioAndPaddings:boxWidth - height:boxHeight - desiredWidth:narrowRecognizerSize.width - desiredHeight:narrowRecognizerSize.height]; - cv::Mat croppedCharacter = - [RecognizerUtils cropImageWithBoundingBox:image - bbox:boxCords - originalBbox:cords - paddings:paddingsBox - originalPaddings:paddings]; - if (self->independentCharacters) { - croppedCharacter = - [RecognizerUtils cropSingleCharacter:croppedCharacter]; - croppedCharacter = - [RecognizerUtils normalizeForRecognizer:croppedCharacter - adjustContrast:0.0 - isVertical:YES]; - NSArray 
*recognitionResult = [recognizer runModel:croppedCharacter]; - NSArray *predIndex = [recognitionResult objectAtIndex:0]; - NSArray *decodedText = - [converter decodeGreedy:predIndex length:(int)(predIndex.count)]; - text = [text stringByAppendingString:decodedText[0]]; - confidenceScore = @([confidenceScore floatValue] + - [[recognitionResult objectAtIndex:1] floatValue]); - } else { - croppedCharacters.push_back(croppedCharacter); - } - } - - if (self->independentCharacters) { - confidenceScore = @([confidenceScore floatValue] / boxResult.count); - } else { - cv::Mat mergedCharacters; - cv::hconcat(croppedCharacters.data(), (int)croppedCharacters.size(), - mergedCharacters); - mergedCharacters = [OCRUtils resizeWithPadding:mergedCharacters - desiredWidth:largeRecognizerWidth - desiredHeight:recognizerHeight]; - mergedCharacters = - [RecognizerUtils normalizeForRecognizer:mergedCharacters - adjustContrast:0.0 - isVertical:NO]; - NSArray *recognitionResult = [recognizer runModel:mergedCharacters]; - NSArray *predIndex = [recognitionResult objectAtIndex:0]; - NSArray *decodedText = [converter decodeGreedy:predIndex - length:(int)(predIndex.count)]; - text = [text stringByAppendingString:decodedText[0]]; - confidenceScore = @([confidenceScore floatValue] + - [[recognitionResult objectAtIndex:1] floatValue]); - } - - NSMutableArray *newCoords = [NSMutableArray arrayWithCapacity:4]; - for (NSValue *cord in cords) { - const CGPoint point = [cord CGPointValue]; - - [newCoords addObject:@{ - @"x" : @((point.x - [paddings[@"left"] intValue]) * - [paddings[@"resizeRatio"] floatValue]), - @"y" : @((point.y - [paddings[@"top"] intValue]) * - [paddings[@"resizeRatio"] floatValue]) - }]; - } - - NSDictionary *res = - @{@"text" : text, @"bbox" : newCoords, @"score" : confidenceScore}; - [predictions addObject:res]; - } - - resolve(predictions); - } @catch (NSException *exception) { - reject(@"forward_error", - [NSString stringWithFormat:@"%@", exception.reason], nil); - } -} - -- 
(std::shared_ptr)getTurboModule: - (const facebook::react::ObjCTurboModule::InitParams &)params { - return std::make_shared(params); -} - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/BaseModel.h b/packages/react-native-executorch/ios/RnExecutorch/models/BaseModel.h deleted file mode 100644 index a8bd4136e3..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/BaseModel.h +++ /dev/null @@ -1,21 +0,0 @@ -#import "ExecutorchLib/ETModel.h" - -@interface BaseModel : NSObject { -@protected - ETModel *module; -} - -- (NSArray *)forward:(NSArray *)inputs; - -- (NSArray *)forward:(NSArray *)inputs - shapes:(NSArray *)shapes - inputTypes:(NSArray *)inputTypes; - -- (NSArray *)execute:(NSString *)methodName - inputs:(NSArray *)inputs - shapes:(NSArray *)shapes - inputTypes:(NSArray *)inputTypes; - -- (NSNumber *)loadModel:(NSString *)modelSource; - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/BaseModel.mm b/packages/react-native-executorch/ios/RnExecutorch/models/BaseModel.mm deleted file mode 100644 index b0f21ed6db..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/BaseModel.mm +++ /dev/null @@ -1,43 +0,0 @@ -#import "BaseModel.h" - -@implementation BaseModel - -- (NSArray *)forward:(NSArray *)inputs { - NSMutableArray *shapes = [NSMutableArray new]; - NSMutableArray *inputTypes = [NSMutableArray new]; - NSNumber *numberOfInputs = [module getNumberOfInputs]; - - for (NSUInteger i = 0; i < [numberOfInputs intValue]; i++) { - [shapes addObject:[module getInputShape:[NSNumber numberWithInt:i]]]; - [inputTypes addObject:[module getInputType:[NSNumber numberWithInt:i]]]; - } - - NSArray *result = [module forward:inputs shapes:shapes inputTypes:inputTypes]; - - return result; -} - -- (NSArray *)forward:(NSArray *)inputs - shapes:(NSArray *)shapes - inputTypes:(NSArray *)inputTypes { - NSArray *result = [module forward:inputs shapes:shapes inputTypes:inputTypes]; - return 
result; -} - -- (NSArray *)execute:(NSString *)methodName - inputs:(NSArray *)inputs - shapes:(NSArray *)shapes - inputTypes:(NSArray *)inputTypes { - NSArray *result = [module execute:methodName - inputs:inputs - shapes:shapes - inputTypes:inputTypes]; - return result; -} - -- (NSNumber *)loadModel:(NSString *)modelSource { - module = [[ETModel alloc] init]; - return [self->module loadModel:modelSource]; -} - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/Detector.h b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/Detector.h deleted file mode 100644 index e1a43898c1..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/Detector.h +++ /dev/null @@ -1,9 +0,0 @@ -#import "../BaseModel.h" -#import "opencv2/opencv.hpp" - -@interface Detector : BaseModel - -- (cv::Size)getModelImageSize; -- (NSArray *)runModel:(cv::Mat &)input; - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/Detector.mm b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/Detector.mm deleted file mode 100644 index 9df5ea36e2..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/Detector.mm +++ /dev/null @@ -1,101 +0,0 @@ -#import "Detector.h" -#import "../../utils/ImageProcessor.h" -#import "utils/Constants.h" -#import "utils/DetectorUtils.h" -#import "utils/OCRUtils.h" - -/* - The model used as detector is based on CRAFT (Character Region Awareness for - Text Detection) paper. 
https://arxiv.org/pdf/1904.01941 - */ - -@implementation Detector { - cv::Size originalSize; - cv::Size modelSize; -} - -- (cv::Size)getModelImageSize { - if (!modelSize.empty()) { - return modelSize; - } - - NSArray *inputShape = [module getInputShape:@0]; - NSNumber *widthNumber = inputShape[inputShape.count - 2]; - NSNumber *heightNumber = inputShape.lastObject; - - const int height = [heightNumber intValue]; - const int width = [widthNumber intValue]; - modelSize = cv::Size(height, width); - - return cv::Size(height, width); -} - -- (NSArray *)preprocess:(cv::Mat &)input { - /* - Detector as an input accepts tensor with a shape of [1, 3, 800, 800]. - Due to big influence of resize to quality of recognition the image preserves - original aspect ratio and the missing parts are filled with padding. - */ - self->originalSize = cv::Size(input.cols, input.rows); - cv::Size modelImageSize = [self getModelImageSize]; - cv::Mat resizedImage; - resizedImage = [OCRUtils resizeWithPadding:input - desiredWidth:modelImageSize.width - desiredHeight:modelImageSize.height]; - NSArray *modelInput = [ImageProcessor matToNSArray:resizedImage - mean:mean - variance:variance]; - return modelInput; -} - -- (NSArray *)postprocess:(NSArray *)output { - /* - The output of the model consists of two matrices (heat maps): - 1. ScoreText(Score map) - The probability of a region containing character - 2. ScoreAffinity(Affinity map) - affinity between characters, used to to - group each character into a single instance (sequence) Both matrices are - 400x400 - - The result of this step is a list of bounding boxes that contain text. - */ - NSArray *predictions = [output objectAtIndex:0]; - - cv::Size modelImageSize = [self getModelImageSize]; - cv::Mat scoreTextCV, scoreAffinityCV; - /* - The output of the model is a matrix in size of input image containing two - matrices representing heatmap. 
Those two matrices are in the size of half of - the input image, that's why the width and height is divided by 2. - */ - [DetectorUtils interleavedArrayToMats:predictions - outputMat1:scoreTextCV - outputMat2:scoreAffinityCV - withSize:cv::Size(modelImageSize.width / 2, - modelImageSize.height / 2)]; - NSArray *bBoxesList = [DetectorUtils getDetBoxesFromTextMap:scoreTextCV - affinityMap:scoreAffinityCV - usingTextThreshold:textThreshold - linkThreshold:linkThreshold - lowTextThreshold:lowTextThreshold]; - bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList - usingRestoreRatio:restoreRatio]; - - bBoxesList = [DetectorUtils groupTextBoxes:bBoxesList - centerThreshold:centerThreshold - distanceThreshold:distanceThreshold - heightThreshold:heightThreshold - minSideThreshold:minSideThreshold - maxSideThreshold:maxSideThreshold - maxWidth:maxWidth]; - - return bBoxesList; -} - -- (NSArray *)runModel:(cv::Mat &)input { - NSArray *modelInput = [self preprocess:input]; - NSArray *modelResult = [self forward:@[ modelInput ]]; - NSArray *result = [self postprocess:modelResult]; - return result; -} - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/RecognitionHandler.h b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/RecognitionHandler.h deleted file mode 100644 index d0031c33d8..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/RecognitionHandler.h +++ /dev/null @@ -1,16 +0,0 @@ -#import "opencv2/opencv.hpp" - -@interface RecognitionHandler : NSObject - -- (instancetype)initWithSymbols:(NSString *)symbols; - -- (NSNumber *)loadRecognizers:(NSString *)largeRecognizerPath - mediumRecognizerPath:(NSString *)mediumRecognizerPath - smallRecognizerPath:(NSString *)smallRecognizerPath; - -- (NSArray *)recognize:(NSArray *)bBoxesList - imgGray:(cv::Mat)imgGray - desiredWidth:(int)desiredWidth - desiredHeight:(int)desiredHeight; - -@end diff --git 
a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/RecognitionHandler.mm b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/RecognitionHandler.mm deleted file mode 100644 index b291ea430d..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/RecognitionHandler.mm +++ /dev/null @@ -1,135 +0,0 @@ -#import "RecognitionHandler.h" -#import "./utils/CTCLabelConverter.h" -#import "./utils/Constants.h" -#import "./utils/OCRUtils.h" -#import "./utils/RecognizerUtils.h" -#import "Recognizer.h" - -/* - RecognitionHandler class is responsible for loading and choosing the - appropriate recognizer model based on the input image size, it also handles - converting the model output to text. - */ - -@implementation RecognitionHandler { - Recognizer *recognizerLarge; - Recognizer *recognizerMedium; - Recognizer *recognizerSmall; - CTCLabelConverter *converter; -} - -- (instancetype)initWithSymbols:(NSString *)symbols { - self = [super init]; - if (self) { - recognizerLarge = [[Recognizer alloc] init]; - recognizerMedium = [[Recognizer alloc] init]; - recognizerSmall = [[Recognizer alloc] init]; - - converter = [[CTCLabelConverter alloc] initWithCharacters:symbols - separatorList:@{}]; - } - return self; -} - -- (NSNumber *)loadRecognizers:(NSString *)largeRecognizerPath - mediumRecognizerPath:(NSString *)mediumRecognizerPath - smallRecognizerPath:(NSString *)smallRecognizerPath { - NSArray *recognizers = - @[ recognizerLarge, recognizerMedium, recognizerSmall ]; - - NSArray *paths = - @[ largeRecognizerPath, mediumRecognizerPath, smallRecognizerPath ]; - - for (NSInteger i = 0; i < recognizers.count; i++) { - Recognizer *recognizer = recognizers[i]; - NSString *path = paths[i]; - - NSNumber *errorCode = [recognizer loadModel:path]; - if ([errorCode intValue] != 0) { - return errorCode; - } - } - - return @0; -} - -- (NSArray *)runModel:(cv::Mat)croppedImage { - NSArray *result; - if (croppedImage.cols >= largeRecognizerWidth) { - 
result = [recognizerLarge runModel:croppedImage]; - } else if (croppedImage.cols >= mediumRecognizerWidth) { - result = [recognizerMedium runModel:croppedImage]; - } else { - result = [recognizerSmall runModel:croppedImage]; - } - - return result; -} - -- (NSArray *)recognize:(NSArray *)bBoxesList - imgGray:(cv::Mat)imgGray - desiredWidth:(int)desiredWidth - desiredHeight:(int)desiredHeight { - NSDictionary *ratioAndPadding = - [RecognizerUtils calculateResizeRatioAndPaddings:imgGray.cols - height:imgGray.rows - desiredWidth:desiredWidth - desiredHeight:desiredHeight]; - const int left = [ratioAndPadding[@"left"] intValue]; - const int top = [ratioAndPadding[@"top"] intValue]; - const CGFloat resizeRatio = [ratioAndPadding[@"resizeRatio"] floatValue]; - imgGray = [OCRUtils resizeWithPadding:imgGray - desiredWidth:desiredWidth - desiredHeight:desiredHeight]; - - NSMutableArray *predictions = [NSMutableArray array]; - for (NSDictionary *box in bBoxesList) { - cv::Mat croppedImage = [RecognizerUtils getCroppedImage:box - image:imgGray - modelHeight:recognizerHeight]; - if (croppedImage.empty()) { - continue; - } - croppedImage = [RecognizerUtils normalizeForRecognizer:croppedImage - adjustContrast:adjustContrast - isVertical:NO]; - NSArray *result = [self runModel:croppedImage]; - - NSNumber *confidenceScore = [result objectAtIndex:1]; - if ([confidenceScore floatValue] < lowConfidenceThreshold) { - cv::rotate(croppedImage, croppedImage, cv::ROTATE_180); - - NSArray *rotatedResult = [self runModel:croppedImage]; - NSNumber *rotatedConfidenceScore = [rotatedResult objectAtIndex:1]; - - if ([rotatedConfidenceScore floatValue] > [confidenceScore floatValue]) { - result = rotatedResult; - confidenceScore = rotatedConfidenceScore; - } - } - - NSArray *predIndex = [result objectAtIndex:0]; - NSArray *decodedTexts = [converter decodeGreedy:predIndex - length:(int)(predIndex.count)]; - - NSMutableArray *bbox = [NSMutableArray arrayWithCapacity:4]; - for (NSValue *coords in 
box[@"bbox"]) { - const CGPoint point = [coords CGPointValue]; - [bbox addObject:@{ - @"x" : @((point.x - left) * resizeRatio), - @"y" : @((point.y - top) * resizeRatio) - }]; - } - - NSDictionary *res = @{ - @"text" : decodedTexts[0], - @"bbox" : bbox, - @"score" : confidenceScore - }; - [predictions addObject:res]; - } - - return predictions; -} - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/Recognizer.h b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/Recognizer.h deleted file mode 100644 index 9d1cd81a04..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/Recognizer.h +++ /dev/null @@ -1,8 +0,0 @@ -#import "../BaseModel.h" -#import "opencv2/opencv.hpp" - -@interface Recognizer : BaseModel - -- (NSArray *)runModel:(cv::Mat &)input; - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/Recognizer.mm b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/Recognizer.mm deleted file mode 100644 index 2457727aa7..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/Recognizer.mm +++ /dev/null @@ -1,77 +0,0 @@ -#import "Recognizer.h" -#import "../../utils/ImageProcessor.h" -#import "RecognizerUtils.h" - -/* - The model used as detector is based on CRNN paper. 
- https://arxiv.org/pdf/1507.05717 - */ - -@implementation Recognizer { - cv::Size originalSize; -} - -- (cv::Size)getModelImageSize { - NSArray *inputShape = [module getInputShape:@0]; - NSNumber *widthNumber = inputShape.lastObject; - NSNumber *heightNumber = inputShape[inputShape.count - 2]; - - const int height = [heightNumber intValue]; - const int width = [widthNumber intValue]; - return cv::Size(height, width); -} - -- (cv::Size)getModelOutputSize { - NSArray *outputShape = [module getOutputShape:@0]; - NSNumber *widthNumber = outputShape.lastObject; - NSNumber *heightNumber = outputShape[outputShape.count - 2]; - - const int height = [heightNumber intValue]; - const int width = [widthNumber intValue]; - return cv::Size(height, width); -} - -- (NSArray *)preprocess:(cv::Mat &)input { - return [ImageProcessor matToNSArrayGray:input]; -} - -- (NSArray *)postprocess:(NSArray *)output { - const int modelOutputHeight = [self getModelOutputSize].height; - NSInteger numElements = [output.firstObject count]; - NSInteger numRows = (numElements + modelOutputHeight - 1) / modelOutputHeight; - cv::Mat resultMat = cv::Mat::zeros(numRows, modelOutputHeight, CV_32F); - NSInteger counter = 0; - NSInteger currentRow = 0; - for (NSNumber *num in output.firstObject) { - resultMat.at(currentRow, counter) = [num floatValue]; - counter++; - if (counter >= modelOutputHeight) { - counter = 0; - currentRow++; - } - } - - cv::Mat probabilities = [RecognizerUtils softmax:resultMat]; - NSMutableArray *predsNorm = - [RecognizerUtils sumProbabilityRows:probabilities - modelOutputHeight:modelOutputHeight]; - probabilities = [RecognizerUtils divideMatrix:probabilities - byVector:predsNorm]; - NSArray *maxValuesIndices = - [RecognizerUtils findMaxValuesAndIndices:probabilities]; - const CGFloat confidenceScore = - [RecognizerUtils computeConfidenceScore:maxValuesIndices[0] - indicesArray:maxValuesIndices[1]]; - - return @[ maxValuesIndices[1], @(confidenceScore) ]; -} - -- (NSArray 
*)runModel:(cv::Mat &)input { - NSArray *modelInput = [self preprocess:input]; - NSArray *modelResult = [self forward:@[ modelInput ]]; - NSArray *result = [self postprocess:modelResult]; - - return result; -} - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/VerticalDetector.h b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/VerticalDetector.h deleted file mode 100644 index 87a3e36be5..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/VerticalDetector.h +++ /dev/null @@ -1,10 +0,0 @@ -#import "../BaseModel.h" -#import "opencv2/opencv.hpp" - -@interface VerticalDetector : BaseModel - -- (instancetype)initWithDetectSingleCharacters:(BOOL)detectSingleCharacters; -- (cv::Size)getModelImageSize; -- (NSArray *)runModel:(cv::Mat &)input; - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/VerticalDetector.mm b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/VerticalDetector.mm deleted file mode 100644 index b3b7dcc663..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/VerticalDetector.mm +++ /dev/null @@ -1,118 +0,0 @@ -#import "VerticalDetector.h" -#import "../../utils/ImageProcessor.h" -#import "utils/Constants.h" -#import "utils/DetectorUtils.h" -#import "utils/OCRUtils.h" - -/* - The model used as detector is based on CRAFT (Character Region Awareness for - Text Detection) paper. 
https://arxiv.org/pdf/1904.01941 - */ - -@implementation VerticalDetector { - cv::Size originalSize; - cv::Size modelSize; - BOOL detectSingleCharacters; -} - -- (instancetype)initWithDetectSingleCharacters:(BOOL)detectSingleCharacters { - self = [super init]; - if (self) { - self->detectSingleCharacters = detectSingleCharacters; - } - return self; -} - -- (cv::Size)getModelImageSize { - if (!modelSize.empty()) { - return modelSize; - } - - NSArray *inputShape = [module getInputShape:@0]; - NSNumber *widthNumber = inputShape[inputShape.count - 2]; - NSNumber *heightNumber = inputShape.lastObject; - - const int height = [heightNumber intValue]; - const int width = [widthNumber intValue]; - modelSize = cv::Size(height, width); - - return cv::Size(height, width); -} - -- (NSArray *)preprocess:(cv::Mat &)input { - /* - Detector as an input accepts tensor with a shape of [1, 3, 800, 800]. - Due to big influence of resize to quality of recognition the image preserves - original aspect ratio and the missing parts are filled with padding. - */ - self->originalSize = cv::Size(input.cols, input.rows); - cv::Size modelImageSize = [self getModelImageSize]; - cv::Mat resizedImage; - resizedImage = [OCRUtils resizeWithPadding:input - desiredWidth:modelImageSize.width - desiredHeight:modelImageSize.height]; - NSArray *modelInput = [ImageProcessor matToNSArray:resizedImage - mean:mean - variance:variance]; - return modelInput; -} - -- (NSArray *)postprocess:(NSArray *)output { - /* - The output of the model consists of two matrices (heat maps): - 1. ScoreText(Score map) - The probability of a region containing character - 2. ScoreAffinity(Affinity map) - affinity between characters, used to to - group each character into a single instance (sequence) Both matrices are - 400x400 - - The result of this step is a list of bounding boxes that contain text. 
- */ - NSArray *predictions = [output objectAtIndex:0]; - - cv::Size modelImageSize = [self getModelImageSize]; - cv::Mat scoreTextCV, scoreAffinityCV; - /* - The output of the model is a matrix in size of input image containing two - matrices representing heatmap. Those two matrices are in the size of half of - the input image, that's why the width and height is divided by 2. - */ - [DetectorUtils interleavedArrayToMats:predictions - outputMat1:scoreTextCV - outputMat2:scoreAffinityCV - withSize:cv::Size(modelImageSize.width / 2, - modelImageSize.height / 2)]; - CGFloat txtThreshold = - (self->detectSingleCharacters) ? textThreshold : textThresholdVertical; - - NSArray *bBoxesList = [DetectorUtils - getDetBoxesFromTextMapVertical:scoreTextCV - affinityMap:scoreAffinityCV - usingTextThreshold:txtThreshold - linkThreshold:linkThreshold - independentCharacters:self->detectSingleCharacters]; - bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList - usingRestoreRatio:restoreRatioVertical]; - - if (self->detectSingleCharacters) { - return bBoxesList; - } - - bBoxesList = [DetectorUtils groupTextBoxes:bBoxesList - centerThreshold:centerThreshold - distanceThreshold:distanceThreshold - heightThreshold:heightThreshold - minSideThreshold:minSideThreshold - maxSideThreshold:maxSideThreshold - maxWidth:maxWidth]; - - return bBoxesList; -} - -- (NSArray *)runModel:(cv::Mat &)input { - NSArray *modelInput = [self preprocess:input]; - NSArray *modelResult = [self forward:@[ modelInput ]]; - NSArray *result = [self postprocess:modelResult]; - return result; -} - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h deleted file mode 100644 index 498710dd03..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h +++ /dev/null @@ -1,16 +0,0 @@ -#import - -@interface CTCLabelConverter : 
NSObject - -@property(strong, nonatomic) NSMutableDictionary *dict; -@property(strong, nonatomic) NSArray *character; -@property(strong, nonatomic) NSDictionary *separatorList; -@property(strong, nonatomic) NSArray *ignoreIdx; -@property(strong, nonatomic) NSDictionary *dictList; - -- (instancetype)initWithCharacters:(NSString *)characters - separatorList:(NSDictionary *)separatorList; -- (NSArray *)decodeGreedy:(NSArray *)textIndex - length:(NSInteger)length; - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm deleted file mode 100644 index 125da6134b..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm +++ /dev/null @@ -1,80 +0,0 @@ -#import "CTCLabelConverter.h" - -@implementation CTCLabelConverter - -- (instancetype)initWithCharacters:(NSString *)characters - separatorList:(NSDictionary *)separatorList { - self = [super init]; - if (self) { - _dict = [NSMutableDictionary dictionary]; - NSMutableArray *mutableCharacters = - [NSMutableArray arrayWithObject:@"[blank]"]; - - for (NSUInteger i = 0; i < [characters length]; i++) { - NSString *charStr = - [NSString stringWithFormat:@"%C", [characters characterAtIndex:i]]; - [mutableCharacters addObject:charStr]; - self.dict[charStr] = @(i + 1); - } - - _character = [mutableCharacters copy]; - _separatorList = separatorList; - - NSMutableArray *ignoreIndexes = [NSMutableArray arrayWithObject:@(0)]; - for (NSString *sep in separatorList.allValues) { - NSUInteger index = [characters rangeOfString:sep].location; - if (index != NSNotFound) { - [ignoreIndexes addObject:@(index)]; - } - } - _ignoreIdx = [ignoreIndexes copy]; - } - return self; -} - -- (NSArray *)decodeGreedy:(NSArray *)textIndex - length:(NSInteger)length { - NSMutableArray *texts = [NSMutableArray array]; - NSUInteger index = 0; - - while (index < 
textIndex.count) { - NSUInteger segmentLength = MIN(length, textIndex.count - index); - NSRange range = NSMakeRange(index, segmentLength); - NSArray *subArray = [textIndex subarrayWithRange:range]; - - NSMutableString *text = [NSMutableString string]; - NSNumber *lastChar = nil; - - NSMutableArray *isNotRepeated = - [NSMutableArray arrayWithObject:@YES]; - NSMutableArray *isNotIgnored = [NSMutableArray array]; - - for (NSUInteger i = 0; i < subArray.count; i++) { - NSNumber *currentChar = subArray[i]; - if (i > 0) { - [isNotRepeated addObject:@(![lastChar isEqualToNumber:currentChar])]; - } - [isNotIgnored addObject:@(![self.ignoreIdx containsObject:currentChar])]; - - lastChar = currentChar; - } - - for (NSUInteger j = 0; j < subArray.count; j++) { - if ([isNotRepeated[j] boolValue] && [isNotIgnored[j] boolValue]) { - NSUInteger charIndex = [subArray[j] unsignedIntegerValue]; - [text appendString:self.character[charIndex]]; - } - } - - [texts addObject:text.copy]; - index += segmentLength; - - if (segmentLength < length) { - break; - } - } - - return texts.copy; -} - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/Constants.h b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/Constants.h deleted file mode 100644 index ba1e162227..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/Constants.h +++ /dev/null @@ -1,26 +0,0 @@ -constexpr CGFloat textThreshold = 0.4; -constexpr CGFloat textThresholdVertical = 0.3; -constexpr CGFloat linkThreshold = 0.4; -constexpr CGFloat lowTextThreshold = 0.7; -constexpr CGFloat centerThreshold = 0.5; -constexpr CGFloat distanceThreshold = 2.0; -constexpr CGFloat heightThreshold = 2.0; -constexpr CGFloat restoreRatio = 3.2; -constexpr CGFloat restoreRatioVertical = 2.0; -constexpr CGFloat singleCharacterCenterThreshold = 0.3; -constexpr CGFloat lowConfidenceThreshold = 0.3; -constexpr CGFloat adjustContrast = 0.2; -constexpr int minSideThreshold = 
15; -constexpr int maxSideThreshold = 30; -constexpr int recognizerHeight = 64; -constexpr int largeRecognizerWidth = 512; -constexpr int mediumRecognizerWidth = 256; -constexpr int smallRecognizerWidth = 128; -constexpr int smallVerticalRecognizerWidth = 64; -constexpr int maxWidth = largeRecognizerWidth + (largeRecognizerWidth * 0.15); -constexpr int minSize = 20; -constexpr int singleCharacterMinSize = 70; -constexpr int recognizerImageSize = 1280; - -const cv::Scalar mean(0.485, 0.456, 0.406); -const cv::Scalar variance(0.229, 0.224, 0.225); \ No newline at end of file diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h deleted file mode 100644 index 1c473b00cf..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h +++ /dev/null @@ -1,31 +0,0 @@ -#import - -constexpr int verticalLineThreshold = 20; - -@interface DetectorUtils : NSObject - -+ (void)interleavedArrayToMats:(NSArray *)array - outputMat1:(cv::Mat &)mat1 - outputMat2:(cv::Mat &)mat2 - withSize:(cv::Size)size; -+ (NSArray *)getDetBoxesFromTextMap:(cv::Mat)textMap - affinityMap:(cv::Mat)affinityMap - usingTextThreshold:(CGFloat)textThreshold - linkThreshold:(CGFloat)linkThreshold - lowTextThreshold:(CGFloat)lowTextThreshold; -+ (NSArray *)getDetBoxesFromTextMapVertical:(cv::Mat)textMap - affinityMap:(cv::Mat)affinityMap - usingTextThreshold:(CGFloat)textThreshold - linkThreshold:(CGFloat)linkThreshold - independentCharacters:(BOOL)independentCharacters; -+ (NSArray *)restoreBboxRatio:(NSArray *)boxes - usingRestoreRatio:(CGFloat)restoreRatio; -+ (NSArray *)groupTextBoxes:(NSArray *)polys - centerThreshold:(CGFloat)centerThreshold - distanceThreshold:(CGFloat)distanceThreshold - heightThreshold:(CGFloat)heightThreshold - minSideThreshold:(int)minSideThreshold - maxSideThreshold:(int)maxSideThreshold - maxWidth:(int)maxWidth; - 
-@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm deleted file mode 100644 index abc7268194..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm +++ /dev/null @@ -1,754 +0,0 @@ -#import "DetectorUtils.h" - -@implementation DetectorUtils - -+ (void)interleavedArrayToMats:(NSArray *)array - outputMat1:(cv::Mat &)mat1 - outputMat2:(cv::Mat &)mat2 - withSize:(cv::Size)size { - mat1 = cv::Mat(size.height, size.width, CV_32F); - mat2 = cv::Mat(size.height, size.width, CV_32F); - - for (NSUInteger idx = 0; idx < array.count; idx++) { - const CGFloat value = [array[idx] doubleValue]; - const int x = (idx / 2) % size.width; - const int y = (idx / 2) / size.width; - - if (idx % 2 == 0) { - mat1.at(y, x) = value; - } else { - mat2.at(y, x) = value; - } - } -} - -+ (NSArray *)getDetBoxesFromTextMapVertical:(cv::Mat)textMap - affinityMap:(cv::Mat)affinityMap - usingTextThreshold:(CGFloat)textThreshold - linkThreshold:(CGFloat)linkThreshold - independentCharacters:(BOOL)independentCharacters { - const int imgH = textMap.rows; - const int imgW = textMap.cols; - cv::Mat textScore; - cv::Mat affinityScore; - cv::threshold(textMap, textScore, textThreshold, 1, cv::THRESH_BINARY); - cv::threshold(affinityMap, affinityScore, linkThreshold, 1, - cv::THRESH_BINARY); - cv::Mat textScoreComb; - if (independentCharacters) { - textScoreComb = textScore - affinityScore; - cv::threshold(textScoreComb, textScoreComb, 0.0, 0, cv::THRESH_TOZERO); - cv::threshold(textScoreComb, textScoreComb, 1.0, 1.0, cv::THRESH_TRUNC); - cv::erode(textScoreComb, textScoreComb, - cv::getStructuringElement(cv::MORPH_RECT, cv::Size(3, 3)), - cv::Point(-1, -1), 1); - cv::dilate(textScoreComb, textScoreComb, - cv::getStructuringElement(cv::MORPH_RECT, cv::Size(3, 3)), - cv::Point(-1, -1), 4); - } else { - textScoreComb = textScore 
+ affinityScore; - cv::threshold(textScoreComb, textScoreComb, 0.0, 0, cv::THRESH_TOZERO); - cv::threshold(textScoreComb, textScoreComb, 1.0, 1.0, cv::THRESH_TRUNC); - cv::dilate(textScoreComb, textScoreComb, - cv::getStructuringElement(cv::MORPH_RECT, cv::Size(3, 3)), - cv::Point(-1, -1), 2); - } - - cv::Mat binaryMat; - textScoreComb.convertTo(binaryMat, CV_8UC1); - - cv::Mat labels, stats, centroids; - const int nLabels = - cv::connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4); - - NSMutableArray *detectedBoxes = [NSMutableArray array]; - for (int i = 1; i < nLabels; i++) { - const int area = stats.at(i, cv::CC_STAT_AREA); - if (area < 20) - continue; - const int width = stats.at(i, cv::CC_STAT_WIDTH); - const int height = stats.at(i, cv::CC_STAT_HEIGHT); - - if (!independentCharacters && height < width) - continue; - - cv::Mat mask = (labels == i); - - cv::Mat segMap = cv::Mat::zeros(textMap.size(), CV_8U); - segMap.setTo(255, mask); - - const int x = stats.at(i, cv::CC_STAT_LEFT); - const int y = stats.at(i, cv::CC_STAT_TOP); - const int w = stats.at(i, cv::CC_STAT_WIDTH); - const int h = stats.at(i, cv::CC_STAT_HEIGHT); - const int dilationRadius = (int)(sqrt((double)(area / MAX(w, h))) * 2.0); - const int sx = MAX(x - dilationRadius, 0); - const int ex = MIN(x + w + dilationRadius + 1, imgW); - const int sy = MAX(y - dilationRadius, 0); - const int ey = MIN(y + h + dilationRadius + 1, imgH); - - cv::Rect roi(sx, sy, ex - sx, ey - sy); - cv::Mat kernel = cv::getStructuringElement( - cv::MORPH_RECT, cv::Size(1 + dilationRadius, 1 + dilationRadius)); - cv::Mat roiSegMap = segMap(roi); - cv::dilate(roiSegMap, roiSegMap, kernel, cv::Point(-1, -1), 2); - - std::vector> contours; - cv::findContours(segMap, contours, cv::RETR_EXTERNAL, - cv::CHAIN_APPROX_SIMPLE); - if (!contours.empty()) { - cv::RotatedRect minRect = cv::minAreaRect(contours[0]); - cv::Point2f vertices[4]; - minRect.points(vertices); - NSMutableArray *pointsArray = [NSMutableArray 
arrayWithCapacity:4]; - for (int j = 0; j < 4; j++) { - const CGPoint point = CGPointMake(vertices[j].x, vertices[j].y); - [pointsArray addObject:[NSValue valueWithCGPoint:point]]; - } - NSDictionary *dict = - @{@"bbox" : pointsArray, @"angle" : @(minRect.angle)}; - [detectedBoxes addObject:dict]; - } - } - - return detectedBoxes; -} - -/** - * This method applies a series of image processing operations to identify - * likely areas of text in the textMap and return the bounding boxes for single - * words. - * - * @param textMap A cv::Mat representing a heat map of the characters of text - * being present in an image. - * @param affinityMap A cv::Mat representing a heat map of the affinity between - * characters. - * @param textThreshold A CGFloat representing the threshold for the text map. - * @param linkThreshold A CGFloat representing the threshold for the affinity - * map. - * @param lowTextThreshold A CGFloat representing the low text. - * - * @return An NSArray containing NSDictionary objects. Each dictionary includes: - * - "bbox": an NSArray of CGPoint values representing the vertices of the - * detected text box. - * - "angle": an NSNumber representing the rotation angle of the box. 
- */ -+ (NSArray *)getDetBoxesFromTextMap:(cv::Mat)textMap - affinityMap:(cv::Mat)affinityMap - usingTextThreshold:(CGFloat)textThreshold - linkThreshold:(CGFloat)linkThreshold - lowTextThreshold:(CGFloat)lowTextThreshold { - const int imgH = textMap.rows; - const int imgW = textMap.cols; - cv::Mat textScore; - cv::Mat affinityScore; - cv::threshold(textMap, textScore, textThreshold, 1, cv::THRESH_BINARY); - cv::threshold(affinityMap, affinityScore, linkThreshold, 1, - cv::THRESH_BINARY); - cv::Mat textScoreComb = textScore + affinityScore; - cv::threshold(textScoreComb, textScoreComb, 0, 1, cv::THRESH_BINARY); - cv::Mat binaryMat; - textScoreComb.convertTo(binaryMat, CV_8UC1); - - cv::Mat labels, stats, centroids; - const int nLabels = - cv::connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4); - - NSMutableArray *detectedBoxes = [NSMutableArray array]; - for (int i = 1; i < nLabels; i++) { - const int area = stats.at(i, cv::CC_STAT_AREA); - if (area < 10) - continue; - - cv::Mat mask = (labels == i); - CGFloat maxVal; - cv::minMaxLoc(textMap, NULL, &maxVal, NULL, NULL, mask); - if (maxVal < lowTextThreshold) - continue; - - cv::Mat segMap = cv::Mat::zeros(textMap.size(), CV_8U); - segMap.setTo(255, mask); - - const int x = stats.at(i, cv::CC_STAT_LEFT); - const int y = stats.at(i, cv::CC_STAT_TOP); - const int w = stats.at(i, cv::CC_STAT_WIDTH); - const int h = stats.at(i, cv::CC_STAT_HEIGHT); - const int dilationRadius = (int)(sqrt((double)(area / MAX(w, h))) * 2.0); - const int sx = MAX(x - dilationRadius, 0); - const int ex = MIN(x + w + dilationRadius + 1, imgW); - const int sy = MAX(y - dilationRadius, 0); - const int ey = MIN(y + h + dilationRadius + 1, imgH); - - cv::Rect roi(sx, sy, ex - sx, ey - sy); - cv::Mat kernel = cv::getStructuringElement( - cv::MORPH_RECT, cv::Size(1 + dilationRadius, 1 + dilationRadius)); - cv::Mat roiSegMap = segMap(roi); - cv::dilate(roiSegMap, roiSegMap, kernel); - - std::vector> contours; - 
cv::findContours(segMap, contours, cv::RETR_EXTERNAL, - cv::CHAIN_APPROX_SIMPLE); - if (!contours.empty()) { - cv::RotatedRect minRect = cv::minAreaRect(contours[0]); - cv::Point2f vertices[4]; - minRect.points(vertices); - NSMutableArray *pointsArray = [NSMutableArray arrayWithCapacity:4]; - for (int j = 0; j < 4; j++) { - const CGPoint point = CGPointMake(vertices[j].x, vertices[j].y); - [pointsArray addObject:[NSValue valueWithCGPoint:point]]; - } - NSDictionary *dict = - @{@"bbox" : pointsArray, @"angle" : @(minRect.angle)}; - [detectedBoxes addObject:dict]; - } - } - - return detectedBoxes; -} - -+ (NSArray *)restoreBboxRatio:(NSArray *)boxes - usingRestoreRatio:(CGFloat)restoreRatio { - NSMutableArray *result = [NSMutableArray array]; - for (NSUInteger i = 0; i < [boxes count]; i++) { - NSDictionary *box = boxes[i]; - NSMutableArray *boxArray = [NSMutableArray arrayWithCapacity:4]; - for (NSValue *value in box[@"bbox"]) { - CGPoint point = [value CGPointValue]; - point.x *= restoreRatio; - point.y *= restoreRatio; - [boxArray addObject:[NSValue valueWithCGPoint:point]]; - } - NSDictionary *dict = @{@"bbox" : boxArray, @"angle" : box[@"angle"]}; - [result addObject:dict]; - } - - return result; -} - -/** - * This method normalizes angle returned from cv::minAreaRect function which - *ranges from 0 to 90 degrees. 
- **/ -+ (CGFloat)normalizeAngle:(CGFloat)angle { - if (angle > 45) { - return angle - 90; - } - return angle; -} - -+ (CGPoint)midpointBetweenPoint:(CGPoint)p1 andPoint:(CGPoint)p2 { - return CGPointMake((p1.x + p2.x) / 2, (p1.y + p2.y) / 2); -} - -+ (CGFloat)distanceFromPoint:(CGPoint)p1 toPoint:(CGPoint)p2 { - const CGFloat xDist = (p2.x - p1.x); - const CGFloat yDist = (p2.y - p1.y); - return sqrt(xDist * xDist + yDist * yDist); -} - -+ (CGPoint)centerOfBox:(NSArray *)box { - return [self midpointBetweenPoint:[box[0] CGPointValue] - andPoint:[box[2] CGPointValue]]; -} - -+ (CGFloat)maxSideLength:(NSArray *)points { - CGFloat maxSideLength = 0; - NSInteger numOfPoints = points.count; - for (NSInteger i = 0; i < numOfPoints; i++) { - const CGPoint currentPoint = [points[i] CGPointValue]; - const CGPoint nextPoint = [points[(i + 1) % numOfPoints] CGPointValue]; - - const CGFloat sideLength = [self distanceFromPoint:currentPoint - toPoint:nextPoint]; - if (sideLength > maxSideLength) { - maxSideLength = sideLength; - } - } - return maxSideLength; -} - -+ (CGFloat)minSideLength:(NSArray *)points { - CGFloat minSideLength = CGFLOAT_MAX; - NSInteger numOfPoints = points.count; - - for (NSInteger i = 0; i < numOfPoints; i++) { - const CGPoint currentPoint = [points[i] CGPointValue]; - const CGPoint nextPoint = [points[(i + 1) % numOfPoints] CGPointValue]; - - const CGFloat sideLength = [self distanceFromPoint:currentPoint - toPoint:nextPoint]; - if (sideLength < minSideLength) { - minSideLength = sideLength; - } - } - - return minSideLength; -} - -+ (CGFloat)calculateMinimalDistanceBetweenBox:(NSArray *)box1 - andBox:(NSArray *)box2 { - CGFloat minDistance = CGFLOAT_MAX; - for (NSValue *value1 in box1) { - const CGPoint corner1 = [value1 CGPointValue]; - for (NSValue *value2 in box2) { - const CGPoint corner2 = [value2 CGPointValue]; - const CGFloat distance = [self distanceFromPoint:corner1 toPoint:corner2]; - if (distance < minDistance) { - minDistance = distance; - 
} - } - } - return minDistance; -} - -+ (NSArray *)rotateBox:(NSArray *)box - withAngle:(CGFloat)angle { - const CGPoint center = [self centerOfBox:box]; - - const CGFloat radians = angle * M_PI / 180.0; - - NSMutableArray *rotatedPoints = - [NSMutableArray arrayWithCapacity:4]; - for (NSValue *value in box) { - const CGPoint point = [value CGPointValue]; - - const CGFloat translatedX = point.x - center.x; - const CGFloat translatedY = point.y - center.y; - - const CGFloat rotatedX = - translatedX * cos(radians) - translatedY * sin(radians); - const CGFloat rotatedY = - translatedX * sin(radians) + translatedY * cos(radians); - - const CGPoint rotatedPoint = - CGPointMake(rotatedX + center.x, rotatedY + center.y); - [rotatedPoints addObject:[NSValue valueWithCGPoint:rotatedPoint]]; - } - - return rotatedPoints; -} - -/** - * Orders a set of points in a clockwise direction starting with the top-left - * point. - * - * Process: - * 1. It iterates through each CGPoint extracted from the NSValues. - * 2. For each point, it calculates the sum (x + y) and difference (y - x) of - * the coordinates. - * 3. Points are classified into: - * - Top-left: Minimum sum. - * - Bottom-right: Maximum sum. - * - Top-right: Minimum difference. - * - Bottom-left: Maximum difference. - * 4. The points are ordered starting from the top-left in a clockwise manner: - * top-left, top-right, bottom-right, bottom-left. 
- */ -+ (NSArray *)orderPointsClockwise:(NSArray *)points { - CGPoint topLeft, topRight, bottomRight, bottomLeft; - CGFloat minSum = FLT_MAX; - CGFloat maxSum = -FLT_MAX; - CGFloat minDiff = FLT_MAX; - CGFloat maxDiff = -FLT_MAX; - - for (NSValue *value in points) { - const CGPoint pt = [value CGPointValue]; - const CGFloat sum = pt.x + pt.y; - const CGFloat diff = pt.y - pt.x; - - if (sum < minSum) { - minSum = sum; - topLeft = pt; - } - if (sum > maxSum) { - maxSum = sum; - bottomRight = pt; - } - if (diff < minDiff) { - minDiff = diff; - topRight = pt; - } - if (diff > maxDiff) { - maxDiff = diff; - bottomLeft = pt; - } - } - - NSArray *rect = @[ - [NSValue valueWithCGPoint:topLeft], [NSValue valueWithCGPoint:topRight], - [NSValue valueWithCGPoint:bottomRight], - [NSValue valueWithCGPoint:bottomLeft] - ]; - - return rect; -} - -+ (std::vector)pointsFromNSValues:(NSArray *)nsValues { - std::vector points; - for (NSValue *value in nsValues) { - const CGPoint point = [value CGPointValue]; - points.emplace_back(point.x, point.y); - } - return points; -} - -+ (NSArray *)nsValuesFromPoints:(cv::Point2f *)points - count:(int)count { - NSMutableArray *nsValues = - [[NSMutableArray alloc] initWithCapacity:count]; - for (int i = 0; i < count; i++) { - [nsValues addObject:[NSValue valueWithCGPoint:CGPointMake(points[i].x, - points[i].y)]]; - } - return nsValues; -} - -+ (NSArray *)mergeRotatedBoxes:(NSArray *)box1 - withBox:(NSArray *)box2 { - box1 = [self orderPointsClockwise:box1]; - box2 = [self orderPointsClockwise:box2]; - - std::vector points1 = [self pointsFromNSValues:box1]; - std::vector points2 = [self pointsFromNSValues:box2]; - - std::vector allPoints; - allPoints.insert(allPoints.end(), points1.begin(), points1.end()); - allPoints.insert(allPoints.end(), points2.begin(), points2.end()); - - std::vector hullIndices; - cv::convexHull(allPoints, hullIndices, false); - - std::vector hullPoints; - for (int idx : hullIndices) { - 
hullPoints.push_back(allPoints[idx]); - } - - cv::RotatedRect minAreaRect = cv::minAreaRect(hullPoints); - - cv::Point2f rectPoints[4]; - minAreaRect.points(rectPoints); - - return [self nsValuesFromPoints:rectPoints count:4]; -} - -+ (NSMutableArray *) - removeSmallBoxesFromArray:(NSArray *)boxes - usingMinSideThreshold:(CGFloat)minSideThreshold - maxSideThreshold:(CGFloat)maxSideThreshold { - NSMutableArray *filteredBoxes = [NSMutableArray array]; - - for (NSDictionary *box in boxes) { - const CGFloat maxSideLength = [self maxSideLength:box[@"bbox"]]; - const CGFloat minSideLength = [self minSideLength:box[@"bbox"]]; - if (minSideLength > minSideThreshold && maxSideLength > maxSideThreshold) { - [filteredBoxes addObject:box]; - } - } - - return filteredBoxes; -} - -+ (CGFloat)minimumYFromBox:(NSArray *)box { - __block CGFloat minY = CGFLOAT_MAX; - [box enumerateObjectsUsingBlock:^(NSValue *_Nonnull obj, NSUInteger idx, - BOOL *_Nonnull stop) { - const CGPoint pt = [obj CGPointValue]; - if (pt.y < minY) { - minY = pt.y; - } - }]; - return minY; -} - -/** - * This method calculates the distances between each sequential pair of points - * in a presumed quadrilateral, identifies the two shortest sides, and fits a - * linear model to the midpoints of these sides. It also evaluates whether the - * resulting line should be considered vertical based on a predefined threshold - * for the x-coordinate differences. - * - * If the line is vertical it is fitted as a function of x = my + c, otherwise - * as y = mx + c. - * - * @return A NSDictionary containing: - * - "slope": NSNumber representing the slope (m) of the line. - * - "intercept": NSNumber representing the line's intercept (c) with y-axis. - * - "isVertical": NSNumber (boolean) indicating whether the line is - * considered vertical. 
- */ -+ (NSDictionary *)fitLineToShortestSides:(NSArray *)points { - NSMutableArray *sides = [NSMutableArray array]; - NSMutableArray *midpoints = [NSMutableArray array]; - - for (int i = 0; i < 4; i++) { - const CGPoint p1 = [points[i] CGPointValue]; - const CGPoint p2 = [points[(i + 1) % 4] CGPointValue]; - - const CGFloat sideLength = [self distanceFromPoint:p1 toPoint:p2]; - [sides addObject:@{@"length" : @(sideLength), @"index" : @(i)}]; - [midpoints - addObject:[NSValue valueWithCGPoint:[self midpointBetweenPoint:p1 - andPoint:p2]]]; - } - - [sides - sortUsingDescriptors:@[ [NSSortDescriptor sortDescriptorWithKey:@"length" - ascending:YES] ]]; - - const CGPoint midpoint1 = - [midpoints [[sides [0] [@"index"] intValue]] CGPointValue]; - const CGPoint midpoint2 = - [midpoints [[sides [1] [@"index"] intValue]] CGPointValue]; - const CGFloat dx = fabs(midpoint2.x - midpoint1.x); - - CGFloat m, c; - BOOL isVertical; - - std::vector cvMidPoints = { - cv::Point2f(midpoint1.x, midpoint1.y), - cv::Point2f(midpoint2.x, midpoint2.y)}; - cv::Vec4f line; - - if (dx < verticalLineThreshold) { - for (auto &pt : cvMidPoints) - std::swap(pt.x, pt.y); - cv::fitLine(cvMidPoints, line, cv::DIST_L2, 0, 0.01, 0.01); - m = line[1] / line[0]; - c = line[3] - m * line[2]; - isVertical = YES; - } else { - cv::fitLine(cvMidPoints, line, cv::DIST_L2, 0, 0.01, 0.01); - m = line[1] / line[0]; - c = line[3] - m * line[2]; - isVertical = NO; - } - - return @{@"slope" : @(m), @"intercept" : @(c), @"isVertical" : @(isVertical)}; -} - -/** - * This method assesses each box from a provided array, checks its center - * against the center of a "current box", and evaluates its alignment with a - * specified line equation. The function specifically searches for the box whose - * center is closest to the current box, that has not been ignored, and fits - * within a defined distance from the line. 
- * - * @param boxes An NSArray of NSDictionary objects where each dictionary - * represents a box with keys "bbox" and "angle". "bbox" is an NSArray of - * NSValue objects each encapsulating CGPoint that define the box vertices. - * "angle" is a NSNumber representing the box's rotation angle. - * @param ignoredIdxs An NSSet of NSNumber objects representing indices of boxes - * to ignore in the evaluation. - * @param currentBox An NSArray of NSValue objects encapsulating CGPoints - * representing the current box to compare against. - * @param isVertical A pointer to a BOOL indicating if the line to compare - * distance to is vertical. - * @param m The slope (gradient) of the line against which the box's alignment - * is checked. - * @param c The y-intercept of the line equation y = mx + c. - * @param centerThreshold A multiplier to determine the threshold for the - * distance between the box's center and the line. - * - * @return A NSDictionary containing: - * - "idx" : NSNumber indicating the index of the found box in the - * original NSArray. - * - "boxHeight" : NSNumber representing the shortest side length of the - * found box. Returns nil if no suitable box is found. 
- */ -+ (NSDictionary *)findClosestBox:(NSArray *)boxes - ignoredIdxs:(NSSet *)ignoredIdxs - currentBox:(NSArray *)currentBox - isVertical:(BOOL)isVertical - m:(CGFloat)m - c:(CGFloat)c - centerThreshold:(CGFloat)centerThreshold { - CGFloat smallestDistance = CGFLOAT_MAX; - NSInteger idx = -1; - CGFloat boxHeight = 0; - const CGPoint centerOfCurrentBox = [self centerOfBox:currentBox]; - - for (NSUInteger i = 0; i < boxes.count; i++) { - if ([ignoredIdxs containsObject:@(i)]) { - continue; - } - NSArray *bbox = boxes[i][@"bbox"]; - const CGPoint centerOfProcessedBox = [self centerOfBox:bbox]; - const CGFloat distanceBetweenCenters = - [self distanceFromPoint:centerOfCurrentBox - toPoint:centerOfProcessedBox]; - - if (distanceBetweenCenters >= smallestDistance) { - continue; - } - - boxHeight = [self minSideLength:bbox]; - - const CGFloat lineDistance = - (isVertical - ? fabs(centerOfProcessedBox.x - (m * centerOfProcessedBox.y + c)) - : fabs(centerOfProcessedBox.y - (m * centerOfProcessedBox.x + c))); - - if (lineDistance < boxHeight * centerThreshold) { - idx = i; - smallestDistance = distanceBetweenCenters; - } - } - - return idx != -1 ? @{@"idx" : @(idx), @"boxHeight" : @(boxHeight)} : nil; -} - -/** - * This method processes an array of text box dictionaries, each containing - * details about individual text boxes, and attempts to group and merge these - * boxes based on specified criteria including proximity, alignment, and size - * thresholds. It prioritizes merging of boxes that are aligned closely in - * angle, are near each other, and whose sizes are compatible based on the given - * thresholds. - * - * @param boxes An array of NSDictionary objects where each dictionary - * represents a text box. Each dictionary must have at least a "bbox" key with - * an NSArray of NSValue wrapping CGPoints defining the box vertices, and an - * "angle" key indicating the orientation of the box. 
- * @param centerThreshold A CGFloat representing the threshold for considering - * the distance between center and fitted line. - * @param distanceThreshold A CGFloat that defines the maximum allowed distance - * between boxes for them to be considered for merging. - * @param heightThreshold A CGFloat representing the maximum allowed difference - * in height between boxes for merging. - * @param minSideThreshold An int that defines the minimum dimension threshold - * to filter out small boxes after grouping. - * @param maxSideThreshold An int that specifies the maximum dimension threshold - * for filtering boxes post-grouping. - * @param maxWidth An int that represents the maximum width allowable for a - * merged box. - * - * @return An NSArray of NSDictionary objects representing the merged boxes. - * Each dictionary contains: - * - "bbox": An NSArray of NSValue each containing a CGPoint that - * defines the vertices of the merged box. - * - "angle": NSNumber representing the computed orientation of the - * merged box. - * - * Processing Steps: - * 1. Sort initial boxes based on their maximum side length. - * 2. Sequentially merge boxes considering alignment, proximity, and size - * compatibility. - * 3. Post-processing to remove any boxes that are too small or exceed max side - * criteria. - * 4. Sort the final array of boxes by their vertical positions. - */ -+ (NSArray *)groupTextBoxes:(NSMutableArray *)boxes - centerThreshold:(CGFloat)centerThreshold - distanceThreshold:(CGFloat)distanceThreshold - heightThreshold:(CGFloat)heightThreshold - minSideThreshold:(int)minSideThreshold - maxSideThreshold:(int)maxSideThreshold - maxWidth:(int)maxWidth { - // Sort boxes based on their maximum side length - boxes = [boxes sortedArrayUsingComparator:^NSComparisonResult( - NSDictionary *obj1, NSDictionary *obj2) { - const CGFloat maxLen1 = [self maxSideLength:obj1[@"bbox"]]; - const CGFloat maxLen2 = [self maxSideLength:obj2[@"bbox"]]; - return (maxLen1 < maxLen2) ? 
NSOrderedDescending - : (maxLen1 > maxLen2) ? NSOrderedAscending - : NSOrderedSame; - }].mutableCopy; - - NSMutableArray *mergedArray = [NSMutableArray array]; - CGFloat lineAngle; - while (boxes.count > 0) { - NSMutableDictionary *currentBox = [boxes[0] mutableCopy]; - CGFloat normalizedAngle = - [self normalizeAngle:[currentBox[@"angle"] floatValue]]; - [boxes removeObjectAtIndex:0]; - NSMutableArray *ignoredIdxs = [NSMutableArray array]; - - while (YES) { - // Find all aligned boxes and merge them until max_size is reached or no - // more boxes can be merged - NSDictionary *fittedLine = - [self fitLineToShortestSides:currentBox[@"bbox"]]; - const CGFloat slope = [fittedLine[@"slope"] floatValue]; - const CGFloat intercept = [fittedLine[@"intercept"] floatValue]; - const BOOL isVertical = [fittedLine[@"isVertical"] boolValue]; - - lineAngle = atan(slope) * 180 / M_PI; - if (isVertical) { - lineAngle = -90; - } - - NSDictionary *closestBoxInfo = - [self findClosestBox:boxes - ignoredIdxs:[NSSet setWithArray:ignoredIdxs] - currentBox:currentBox[@"bbox"] - isVertical:isVertical - m:slope - c:intercept - centerThreshold:centerThreshold]; - if (closestBoxInfo == nil) - break; - - NSInteger candidateIdx = [closestBoxInfo[@"idx"] integerValue]; - NSMutableDictionary *candidateBox = [boxes[candidateIdx] mutableCopy]; - const CGFloat candidateHeight = [closestBoxInfo[@"boxHeight"] floatValue]; - - if (([candidateBox[@"angle"] isEqual:@90] && !isVertical) || - ([candidateBox[@"angle"] isEqual:@0] && isVertical)) { - candidateBox[@"bbox"] = [self rotateBox:candidateBox[@"bbox"] - withAngle:normalizedAngle]; - } - - const CGFloat minDistance = - [self calculateMinimalDistanceBetweenBox:candidateBox[@"bbox"] - andBox:currentBox[@"bbox"]]; - const CGFloat mergedHeight = [self minSideLength:currentBox[@"bbox"]]; - if (minDistance < distanceThreshold * candidateHeight && - fabs(mergedHeight - candidateHeight) < - candidateHeight * heightThreshold) { - currentBox[@"bbox"] = [self 
mergeRotatedBoxes:currentBox[@"bbox"] - withBox:candidateBox[@"bbox"]]; - [boxes removeObjectAtIndex:candidateIdx]; - [ignoredIdxs removeAllObjects]; - if ([self maxSideLength:currentBox[@"bbox"]] > maxWidth) { - break; - } - } else { - [ignoredIdxs addObject:@(candidateIdx)]; - } - } - - [mergedArray - addObject:@{@"bbox" : currentBox[@"bbox"], @"angle" : @(lineAngle)}]; - } - - // Remove small boxes and sort by vertical - mergedArray = [self removeSmallBoxesFromArray:mergedArray - usingMinSideThreshold:minSideThreshold - maxSideThreshold:maxSideThreshold]; - - NSArray *sortedBoxes = [mergedArray - sortedArrayUsingComparator:^NSComparisonResult(NSDictionary *obj1, - NSDictionary *obj2) { - NSArray *coords1 = obj1[@"bbox"]; - NSArray *coords2 = obj2[@"bbox"]; - const CGFloat minY1 = [self minimumYFromBox:coords1]; - const CGFloat minY2 = [self minimumYFromBox:coords2]; - return (minY1 < minY2) ? NSOrderedAscending - : (minY1 > minY2) ? NSOrderedDescending - : NSOrderedSame; - }]; - - NSMutableArray *orderedSortedBoxes = - [[NSMutableArray alloc] initWithCapacity:[sortedBoxes count]]; - for (NSDictionary *dict in sortedBoxes) { - NSMutableDictionary *mutableDict = [dict mutableCopy]; - NSArray *originalBBox = mutableDict[@"bbox"]; - NSArray *orderedBBox = [self orderPointsClockwise:originalBBox]; - mutableDict[@"bbox"] = orderedBBox; - [orderedSortedBoxes addObject:mutableDict]; - } - - return orderedSortedBoxes; -} - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/OCRUtils.h b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/OCRUtils.h deleted file mode 100644 index 90a8fa7a43..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/OCRUtils.h +++ /dev/null @@ -1,10 +0,0 @@ -#import - -@interface OCRUtils : NSObject - -+ (cv::Mat)resizeWithPadding:(cv::Mat)img - desiredWidth:(int)desiredWidth - desiredHeight:(int)desiredHeight; -+ (cv::Rect)extractBoundingBox:(NSArray *)coords; - 
-@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm deleted file mode 100644 index a7a7a22d8a..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm +++ /dev/null @@ -1,67 +0,0 @@ -#import "OCRUtils.h" - -@implementation OCRUtils - -+ (cv::Mat)resizeWithPadding:(cv::Mat)img - desiredWidth:(int)desiredWidth - desiredHeight:(int)desiredHeight { - const int height = img.rows; - const int width = img.cols; - const float heightRatio = (float)desiredHeight / height; - const float widthRatio = (float)desiredWidth / width; - const float resizeRatio = MIN(heightRatio, widthRatio); - - const int newWidth = width * resizeRatio; - const int newHeight = height * resizeRatio; - - cv::Mat resizedImg; - cv::resize(img, resizedImg, cv::Size(newWidth, newHeight), 0, 0, - cv::INTER_AREA); - - const int cornerPatchSize = MAX(1, MIN(height, width) / 30); - std::vector corners = { - img(cv::Rect(0, 0, cornerPatchSize, cornerPatchSize)), - img(cv::Rect(width - cornerPatchSize, 0, cornerPatchSize, - cornerPatchSize)), - img(cv::Rect(0, height - cornerPatchSize, cornerPatchSize, - cornerPatchSize)), - img(cv::Rect(width - cornerPatchSize, height - cornerPatchSize, - cornerPatchSize, cornerPatchSize))}; - - cv::Scalar backgroundScalar = cv::mean(corners[0]); - for (int i = 1; i < corners.size(); i++) { - backgroundScalar += cv::mean(corners[i]); - } - backgroundScalar /= (double)corners.size(); - - backgroundScalar[0] = cvFloor(backgroundScalar[0]); - backgroundScalar[1] = cvFloor(backgroundScalar[1]); - backgroundScalar[2] = cvFloor(backgroundScalar[2]); - - const int deltaW = desiredWidth - newWidth; - const int deltaH = desiredHeight - newHeight; - const int top = deltaH / 2; - const int bottom = deltaH - top; - const int left = deltaW / 2; - const int right = deltaW - left; - - cv::Mat centeredImg; - 
cv::copyMakeBorder(resizedImg, centeredImg, top, bottom, left, right, - cv::BORDER_CONSTANT, backgroundScalar); - - return centeredImg; -} - -+ (cv::Rect)extractBoundingBox:(NSArray *)coords { - std::vector points; - points.reserve(coords.count); - for (NSValue *value in coords) { - const CGPoint point = [value CGPointValue]; - - points.emplace_back(point.x, point.y); - } - - return cv::boundingRect(points); -} - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h deleted file mode 100644 index 51d93638a3..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h +++ /dev/null @@ -1,35 +0,0 @@ -#import - -@interface RecognizerUtils : NSObject - -+ (CGFloat)calculateRatio:(int)width height:(int)height; -+ (cv::Mat)computeRatioAndResize:(cv::Mat)img - width:(int)width - height:(int)height - modelHeight:(int)modelHeight; -+ (cv::Mat)normalizeForRecognizer:(cv::Mat)image - adjustContrast:(double)adjustContrast - isVertical:(BOOL)isVertical; -+ (cv::Mat)adjustContrastGrey:(cv::Mat)img target:(double)target; -+ (cv::Mat)divideMatrix:(cv::Mat)matrix byVector:(NSArray *)vector; -+ (cv::Mat)softmax:(cv::Mat)inputs; -+ (NSDictionary *)calculateResizeRatioAndPaddings:(int)width - height:(int)height - desiredWidth:(int)desiredWidth - desiredHeight:(int)desiredHeight; -+ (cv::Mat)getCroppedImage:(NSDictionary *)box - image:(cv::Mat)image - modelHeight:(int)modelHeight; -+ (NSMutableArray *)sumProbabilityRows:(cv::Mat)probabilities - modelOutputHeight:(int)modelOutputHeight; -+ (NSArray *)findMaxValuesAndIndices:(cv::Mat)probabilities; -+ (double)computeConfidenceScore:(NSArray *)valuesArray - indicesArray:(NSArray *)indicesArray; -+ (cv::Mat)cropImageWithBoundingBox:(cv::Mat &)img - bbox:(NSArray *)bbox - originalBbox:(NSArray *)originalBbox - paddings:(NSDictionary *)paddings - 
originalPaddings:(NSDictionary *)originalPaddings; -+ (cv::Mat)cropSingleCharacter:(cv::Mat)img; - -@end diff --git a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm b/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm deleted file mode 100644 index 47f186daf8..0000000000 --- a/packages/react-native-executorch/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm +++ /dev/null @@ -1,331 +0,0 @@ -#import "RecognizerUtils.h" -#import "Constants.h" -#import "OCRUtils.h" - -@implementation RecognizerUtils - -+ (CGFloat)calculateRatio:(int)width height:(int)height { - CGFloat ratio = (CGFloat)width / (CGFloat)height; - if (ratio < 1.0) { - ratio = 1.0 / ratio; - } - return ratio; -} - -+ (cv::Mat)computeRatioAndResize:(cv::Mat)img - width:(int)width - height:(int)height - modelHeight:(int)modelHeight { - CGFloat ratio = (CGFloat)width / (CGFloat)height; - if (ratio < 1.0) { - ratio = [self calculateRatio:width height:height]; - cv::resize(img, img, cv::Size(modelHeight, (int)(modelHeight * ratio)), 0, - 0, cv::INTER_LANCZOS4); - } else { - cv::resize(img, img, cv::Size((int)(modelHeight * ratio), modelHeight), 0, - 0, cv::INTER_LANCZOS4); - } - return img; -} - -+ (cv::Mat)adjustContrastGrey:(cv::Mat)img target:(double)target { - double contrast = 0.0; - int high = 0; - int low = 255; - - for (int i = 0; i < img.rows; ++i) { - for (int j = 0; j < img.cols; ++j) { - uchar pixel = img.at(i, j); - high = MAX(high, pixel); - low = MIN(low, pixel); - } - } - contrast = (high - low) / 255.0; - - if (contrast < target) { - const double ratio = 200.0 / MAX(10, high - low); - img.convertTo(img, CV_32F); - img = ((img - low + 25) * ratio); - - cv::threshold(img, img, 255, 255, cv::THRESH_TRUNC); - cv::threshold(img, img, 0, 0, cv::THRESH_TOZERO); - - img.convertTo(img, CV_8U); - } - - return img; -} - -+ (cv::Mat)normalizeForRecognizer:(cv::Mat)image - adjustContrast:(double)adjustContrast - 
isVertical:(BOOL)isVertical { - if (adjustContrast > 0) { - image = [self adjustContrastGrey:image target:adjustContrast]; - } - - int desiredWidth = - (isVertical) ? smallVerticalRecognizerWidth : smallRecognizerWidth; - - if (image.cols >= largeRecognizerWidth) { - desiredWidth = largeRecognizerWidth; - } else if (image.cols >= mediumRecognizerWidth) { - desiredWidth = mediumRecognizerWidth; - } - - image = [OCRUtils resizeWithPadding:image - desiredWidth:desiredWidth - desiredHeight:recognizerHeight]; - - image.convertTo(image, CV_32F, 1.0 / 255.0); - image = (image - 0.5) * 2.0; - - return image; -} - -+ (cv::Mat)divideMatrix:(cv::Mat)matrix byVector:(NSArray *)vector { - cv::Mat result = matrix.clone(); - - for (int i = 0; i < matrix.rows; i++) { - const float divisor = [vector[i] floatValue]; - for (int j = 0; j < matrix.cols; j++) { - result.at(i, j) /= divisor; - } - } - - return result; -} - -+ (cv::Mat)softmax:(cv::Mat)inputs { - cv::Mat maxVal; - cv::reduce(inputs, maxVal, 1, cv::REDUCE_MAX, CV_32F); - cv::Mat expInputs; - cv::exp(inputs - cv::repeat(maxVal, 1, inputs.cols), expInputs); - cv::Mat sumExp; - cv::reduce(expInputs, sumExp, 1, cv::REDUCE_SUM, CV_32F); - cv::Mat softmaxOutput = expInputs / cv::repeat(sumExp, 1, inputs.cols); - return softmaxOutput; -} - -+ (NSDictionary *)calculateResizeRatioAndPaddings:(int)width - height:(int)height - desiredWidth:(int)desiredWidth - desiredHeight:(int)desiredHeight { - const float newRatioH = (float)desiredHeight / height; - const float newRatioW = (float)desiredWidth / width; - float resizeRatio = MIN(newRatioH, newRatioW); - const int newWidth = width * resizeRatio; - const int newHeight = height * resizeRatio; - const int deltaW = desiredWidth - newWidth; - const int deltaH = desiredHeight - newHeight; - const int top = deltaH / 2; - const int left = deltaW / 2; - const float heightRatio = (float)height / desiredHeight; - const float widthRatio = (float)width / desiredWidth; - - resizeRatio = 
MAX(heightRatio, widthRatio); - - return @{ - @"resizeRatio" : @(resizeRatio), - @"top" : @(top), - @"left" : @(left), - }; -} - -+ (cv::Mat)getCroppedImage:(NSDictionary *)box - image:(cv::Mat)image - modelHeight:(int)modelHeight { - NSArray *coords = box[@"bbox"]; - const CGFloat angle = [box[@"angle"] floatValue]; - - std::vector points; - for (NSValue *value in coords) { - const CGPoint point = [value CGPointValue]; - points.emplace_back(static_cast(point.x), - static_cast(point.y)); - } - - cv::RotatedRect rotatedRect = cv::minAreaRect(points); - - cv::Point2f imageCenter = cv::Point2f(image.cols / 2.0, image.rows / 2.0); - cv::Mat rotationMatrix = cv::getRotationMatrix2D(imageCenter, angle, 1.0); - cv::Mat rotatedImage; - cv::warpAffine(image, rotatedImage, rotationMatrix, image.size(), - cv::INTER_LINEAR); - cv::Point2f rectPoints[4]; - rotatedRect.points(rectPoints); - std::vector transformedPoints(4); - cv::Mat rectMat(4, 2, CV_32FC2, rectPoints); - cv::transform(rectMat, rectMat, rotationMatrix); - - for (int i = 0; i < 4; ++i) { - transformedPoints[i] = rectPoints[i]; - } - - cv::Rect boundingBox = cv::boundingRect(transformedPoints); - boundingBox &= cv::Rect(0, 0, rotatedImage.cols, rotatedImage.rows); - cv::Mat croppedImage = rotatedImage(boundingBox); - if (boundingBox.width == 0 || boundingBox.height == 0) { - croppedImage = cv::Mat().empty(); - - return croppedImage; - } - - croppedImage = [self computeRatioAndResize:croppedImage - width:boundingBox.width - height:boundingBox.height - modelHeight:modelHeight]; - - return croppedImage; -} - -+ (NSMutableArray *)sumProbabilityRows:(cv::Mat)probabilities - modelOutputHeight:(int)modelOutputHeight { - NSMutableArray *predsNorm = - [NSMutableArray arrayWithCapacity:probabilities.rows]; - for (int i = 0; i < probabilities.rows; i++) { - float sum = 0.0; - for (int j = 0; j < modelOutputHeight; j++) { - sum += probabilities.at(i, j); - } - [predsNorm addObject:@(sum)]; - } - return predsNorm; -} - -+ 
(NSArray *)findMaxValuesAndIndices:(cv::Mat)probabilities { - NSMutableArray *valuesArray = [NSMutableArray array]; - NSMutableArray *indicesArray = [NSMutableArray array]; - for (int i = 0; i < probabilities.rows; i++) { - double maxVal = 0; - cv::Point maxLoc; - cv::minMaxLoc(probabilities.row(i), NULL, &maxVal, NULL, &maxLoc); - [valuesArray addObject:@(maxVal)]; - [indicesArray addObject:@(maxLoc.x)]; - } - return @[ valuesArray, indicesArray ]; -} - -+ (double)computeConfidenceScore:(NSArray *)valuesArray - indicesArray:(NSArray *)indicesArray { - NSMutableArray *predsMaxProb = [NSMutableArray array]; - for (NSUInteger index = 0; index < indicesArray.count; index++) { - NSNumber *indicator = indicesArray[index]; - if ([indicator intValue] != 0) { - [predsMaxProb addObject:valuesArray[index]]; - } - } - if (predsMaxProb.count == 0) { - [predsMaxProb addObject:@(0)]; - } - double product = 1.0; - for (NSNumber *prob in predsMaxProb) { - product *= [prob doubleValue]; - } - return pow(product, 2.0 / sqrt(predsMaxProb.count)); -} - -+ (cv::Mat)cropImageWithBoundingBox:(cv::Mat &)img - bbox:(NSArray *)bbox - originalBbox:(NSArray *)originalBbox - paddings:(NSDictionary *)paddings - originalPaddings:(NSDictionary *)originalPaddings { - CGPoint topLeft = [originalBbox[0] CGPointValue]; - std::vector points; - points.reserve(bbox.count); - for (NSValue *coords in bbox) { - CGPoint point = [coords CGPointValue]; - - point.x = point.x - [paddings[@"left"] intValue]; - point.y = point.y - [paddings[@"top"] intValue]; - - point.x = point.x * [paddings[@"resizeRatio"] floatValue]; - point.y = point.y * [paddings[@"resizeRatio"] floatValue]; - - point.x = point.x + topLeft.x; - point.y = point.y + topLeft.y; - - point.x = point.x - [originalPaddings[@"left"] intValue]; - point.y = point.y - [originalPaddings[@"top"] intValue]; - - point.x = point.x * [originalPaddings[@"resizeRatio"] floatValue]; - point.y = point.y * [originalPaddings[@"resizeRatio"] floatValue]; - - 
points.emplace_back(cv::Point2f(point.x, point.y)); - } - - cv::Rect rect = cv::boundingRect(points); - cv::Mat croppedImage = img(rect); - cv::cvtColor(croppedImage, croppedImage, cv::COLOR_BGR2GRAY); - cv::resize(croppedImage, croppedImage, - cv::Size(smallVerticalRecognizerWidth, recognizerHeight), 0, 0, - cv::INTER_AREA); - cv::medianBlur(img, img, 1); - return croppedImage; -} - -+ (cv::Mat)cropSingleCharacter:(cv::Mat)img { - - cv::Mat histogram; - - int histSize = 256; - float range[] = {0, 256}; - const float *histRange = {range}; - bool uniform = true, accumulate = false; - - cv::calcHist(&img, 1, 0, cv::Mat(), histogram, 1, &histSize, &histRange, - uniform, accumulate); - - int midPoint = histSize / 2; - - double sumLeft = 0.0, sumRight = 0.0; - for (int i = 0; i < midPoint; i++) { - sumLeft += histogram.at(i); - } - for (int i = midPoint; i < histSize; i++) { - sumRight += histogram.at(i); - } - - const int thresholdType = - (sumLeft < sumRight) ? cv::THRESH_BINARY_INV : cv::THRESH_BINARY; - - cv::Mat thresh; - cv::threshold(img, thresh, 0, 255, thresholdType + cv::THRESH_OTSU); - - cv::Mat labels, stats, centroids; - const int numLabels = - connectedComponentsWithStats(thresh, labels, stats, centroids, 8); - const CGFloat centralThreshold = singleCharacterCenterThreshold; - const int height = thresh.rows; - const int width = thresh.cols; - - const int minX = centralThreshold * width; - const int maxX = (1 - centralThreshold) * width; - const int minY = centralThreshold * height; - const int maxY = (1 - centralThreshold) * height; - - int selectedComponent = -1; - - for (int i = 1; i < numLabels; i++) { - const int area = stats.at(i, cv::CC_STAT_AREA); - const double cx = centroids.at(i, 0); - const double cy = centroids.at(i, 1); - - if (minX < cx && cx < maxX && minY < cy && cy < maxY && - area > singleCharacterMinSize) { - if (selectedComponent == -1 || - area > stats.at(selectedComponent, cv::CC_STAT_AREA)) { - selectedComponent = i; - } - } - } - 
cv::Mat mask = cv::Mat::zeros(img.size(), CV_8UC1); - if (selectedComponent != -1) { - mask = (labels == selectedComponent) / 255; - } - cv::Mat resultImage = cv::Mat::zeros(img.size(), img.type()); - img.copyTo(resultImage, mask); - cv::bitwise_not(resultImage, resultImage); - return resultImage; -} - -@end diff --git a/packages/react-native-executorch/src/controllers/OCRController.ts b/packages/react-native-executorch/src/controllers/OCRController.ts index 7bb6fc0251..e376f3ff42 100644 --- a/packages/react-native-executorch/src/controllers/OCRController.ts +++ b/packages/react-native-executorch/src/controllers/OCRController.ts @@ -1,41 +1,37 @@ import { symbols } from '../constants/ocr/symbols'; import { ETError, getError } from '../Error'; -import { OCRNativeModule } from '../native/RnExecutorchModules'; import { ResourceSource } from '../types/common'; import { OCRLanguage } from '../types/ocr'; import { ResourceFetcher } from '../utils/ResourceFetcher'; export class OCRController { - private nativeModule: typeof OCRNativeModule; + private nativeModule: any; public isReady: boolean = false; public isGenerating: boolean = false; public error: string | null = null; - private modelDownloadProgressCallback: (downloadProgress: number) => void; private isReadyCallback: (isReady: boolean) => void; private isGeneratingCallback: (isGenerating: boolean) => void; private errorCallback: (error: string) => void; constructor({ - modelDownloadProgressCallback = (_downloadProgress: number) => {}, isReadyCallback = (_isReady: boolean) => {}, isGeneratingCallback = (_isGenerating: boolean) => {}, errorCallback = (_error: string) => {}, - }) { - this.nativeModule = OCRNativeModule; - this.modelDownloadProgressCallback = modelDownloadProgressCallback; + } = {}) { this.isReadyCallback = isReadyCallback; this.isGeneratingCallback = isGeneratingCallback; this.errorCallback = errorCallback; } - public loadModel = async ( + public load = async ( detectorSource: ResourceSource, 
recognizerSources: { recognizerLarge: ResourceSource; recognizerMedium: ResourceSource; recognizerSmall: ResourceSource; }, - language: OCRLanguage + language: OCRLanguage, + onDownloadProgressCallback?: (downloadProgress: number) => void ) => { try { if (!detectorSource || Object.keys(recognizerSources).length !== 3) @@ -49,7 +45,7 @@ export class OCRController { this.isReadyCallback(false); const paths = await ResourceFetcher.fetch( - this.modelDownloadProgressCallback, + onDownloadProgressCallback, detectorSource, recognizerSources.recognizerLarge, recognizerSources.recognizerMedium, @@ -58,14 +54,13 @@ export class OCRController { if (paths === null || paths?.length < 4) { throw new Error('Download interrupted!'); } - await this.nativeModule.loadModule( + this.nativeModule = global.loadOCR( paths[0]!, paths[1]!, paths[2]!, paths[3]!, symbols[language] ); - this.isReady = true; this.isReadyCallback(this.isReady); } catch (e) { @@ -88,7 +83,7 @@ export class OCRController { try { this.isGenerating = true; this.isGeneratingCallback(this.isGenerating); - return await this.nativeModule.forward(input); + return await this.nativeModule.generate(input); } catch (e) { throw new Error(getError(e)); } finally { @@ -96,4 +91,16 @@ export class OCRController { this.isGeneratingCallback(this.isGenerating); } }; + + public delete() { + if (this.isGenerating) { + throw new Error( + getError(ETError.ModelGenerating) + + 'You cannot delete the model. You must wait until the generating is finished.' 
+ ); + } + this.nativeModule.unload(); + this.isReadyCallback(false); + this.isGeneratingCallback(false); + } } diff --git a/packages/react-native-executorch/src/controllers/VerticalOCRController.ts b/packages/react-native-executorch/src/controllers/VerticalOCRController.ts index 62cc412c1d..7fd3b7c846 100644 --- a/packages/react-native-executorch/src/controllers/VerticalOCRController.ts +++ b/packages/react-native-executorch/src/controllers/VerticalOCRController.ts @@ -1,34 +1,29 @@ import { symbols } from '../constants/ocr/symbols'; import { ETError, getError } from '../Error'; -import { VerticalOCRNativeModule } from '../native/RnExecutorchModules'; import { ResourceSource } from '../types/common'; import { OCRLanguage } from '../types/ocr'; import { ResourceFetcher } from '../utils/ResourceFetcher'; export class VerticalOCRController { - private ocrNativeModule: typeof VerticalOCRNativeModule; + private ocrNativeModule: any; public isReady: boolean = false; public isGenerating: boolean = false; public error: string | null = null; - private modelDownloadProgressCallback: (downloadProgress: number) => void; private isReadyCallback: (isReady: boolean) => void; private isGeneratingCallback: (isGenerating: boolean) => void; private errorCallback: (error: string) => void; constructor({ - modelDownloadProgressCallback = (_downloadProgress: number) => {}, isReadyCallback = (_isReady: boolean) => {}, isGeneratingCallback = (_isGenerating: boolean) => {}, errorCallback = (_error: string) => {}, - }) { - this.ocrNativeModule = VerticalOCRNativeModule; - this.modelDownloadProgressCallback = modelDownloadProgressCallback; + } = {}) { this.isReadyCallback = isReadyCallback; this.isGeneratingCallback = isGeneratingCallback; this.errorCallback = errorCallback; } - public loadModel = async ( + public load = async ( detectorSources: { detectorLarge: ResourceSource; detectorNarrow: ResourceSource; @@ -38,7 +33,8 @@ export class VerticalOCRController { recognizerSmall: 
ResourceSource; }, language: OCRLanguage, - independentCharacters: boolean + independentCharacters: boolean, + onDownloadProgressCallback: (downloadProgress: number) => void ) => { try { if ( @@ -55,7 +51,7 @@ export class VerticalOCRController { this.isReadyCallback(this.isReady); const paths = await ResourceFetcher.fetch( - this.modelDownloadProgressCallback, + onDownloadProgressCallback, detectorSources.detectorLarge, detectorSources.detectorNarrow, independentCharacters @@ -65,7 +61,7 @@ export class VerticalOCRController { if (paths === null || paths.length < 3) { throw new Error('Download interrupted'); } - await this.ocrNativeModule.loadModule( + this.ocrNativeModule = global.loadVerticalOCR( paths[0]!, paths[1]!, paths[2]!, @@ -95,7 +91,7 @@ export class VerticalOCRController { try { this.isGenerating = true; this.isGeneratingCallback(this.isGenerating); - return await this.ocrNativeModule.forward(input); + return await this.ocrNativeModule.generate(input); } catch (e) { throw new Error(getError(e)); } finally { @@ -103,4 +99,16 @@ export class VerticalOCRController { this.isGeneratingCallback(this.isGenerating); } }; + + public delete() { + if (this.isGenerating) { + throw new Error( + getError(ETError.ModelGenerating) + + 'You cannot delete the model. You must wait until the generating is finished.' 
+ ); + } + this.ocrNativeModule.unload(); + this.isReadyCallback(false); + this.isGeneratingCallback(false); + } } diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts b/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts index 1aa5a2e51c..b4fe0bf31a 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts @@ -32,7 +32,6 @@ export const useOCR = ({ const controllerInstance = useMemo( () => new OCRController({ - modelDownloadProgressCallback: setDownloadProgress, isReadyCallback: setIsReady, isGeneratingCallback: setIsGenerating, errorCallback: setError, @@ -41,21 +40,24 @@ export const useOCR = ({ ); useEffect(() => { - const loadModel = async () => { - await controllerInstance.loadModel( + if (preventLoad) return; + + (async () => { + await controllerInstance.load( model.detectorSource, { recognizerLarge: model.recognizerLarge, recognizerMedium: model.recognizerMedium, recognizerSmall: model.recognizerSmall, }, - model.language + model.language, + setDownloadProgress ); - }; + })(); - if (!preventLoad) { - loadModel(); - } + return () => { + controllerInstance.delete(); + }; }, [ controllerInstance, model.detectorSource, diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts b/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts index 31ea5832d8..9144539ff4 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts @@ -34,7 +34,6 @@ export const useVerticalOCR = ({ const controllerInstance = useMemo( () => new VerticalOCRController({ - modelDownloadProgressCallback: setDownloadProgress, isReadyCallback: setIsReady, isGeneratingCallback: setIsGenerating, errorCallback: setError, @@ -46,7 +45,7 @@ export const useVerticalOCR = ({ if (preventLoad) 
return; (async () => { - await controllerInstance.loadModel( + await controllerInstance.load( { detectorLarge: model.detectorLarge, detectorNarrow: model.detectorNarrow, @@ -56,9 +55,14 @@ export const useVerticalOCR = ({ recognizerSmall: model.recognizerSmall, }, model.language, - independentCharacters + independentCharacters, + setDownloadProgress ); })(); + + return () => { + controllerInstance.delete(); + }; }, [ controllerInstance, model.detectorLarge, diff --git a/packages/react-native-executorch/src/index.tsx b/packages/react-native-executorch/src/index.tsx index 48471df8e8..464176816e 100644 --- a/packages/react-native-executorch/src/index.tsx +++ b/packages/react-native-executorch/src/index.tsx @@ -18,6 +18,20 @@ declare global { decoderSource: string, modelName: string ) => any; + var loadOCR: ( + detectorSource: string, + recognizerLarge: string, + recognizerMedium: string, + recognizerSmall: string, + symbols: string + ) => any; + var loadVerticalOCR: ( + detectorLarge: string, + detectorNarrow: string, + recognizer: string, + symbols: string, + independentCharacters?: boolean + ) => any; } // eslint-disable no-var if ( @@ -30,7 +44,9 @@ if ( global.loadTextEmbeddings == null || global.loadImageEmbeddings == null || global.loadLLM == null || - global.loadSpeechToText == null + global.loadSpeechToText == null || + global.loadOCR == null || + global.loadVerticalOCR == null ) { if (!ETInstallerNativeModule) { throw new Error( diff --git a/packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts b/packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts index e7ffd4cd9a..e9103814d6 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts @@ -3,9 +3,13 @@ import { ResourceSource } from '../../types/common'; import { OCRLanguage } from '../../types/ocr'; export class OCRModule { - static module: OCRController; + 
private controller: OCRController; - static async load( + constructor() { + this.controller = new OCRController(); + } + + async load( model: { detectorSource: ResourceSource; recognizerLarge: ResourceSource; @@ -15,22 +19,23 @@ export class OCRModule { }, onDownloadProgressCallback: (progress: number) => void = () => {} ) { - this.module = new OCRController({ - modelDownloadProgressCallback: onDownloadProgressCallback, - }); - - await this.module.loadModel( + await this.controller.load( model.detectorSource, { recognizerLarge: model.recognizerLarge, recognizerMedium: model.recognizerMedium, recognizerSmall: model.recognizerSmall, }, - model.language + model.language, + onDownloadProgressCallback ); } - static async forward(input: string) { - return await this.module.forward(input); + async forward(input: string) { + return await this.controller.forward(input); + } + + delete() { + this.controller.delete(); } } diff --git a/packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts b/packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts index aaa4e83a10..9f6d9b5f27 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts @@ -3,9 +3,13 @@ import { ResourceSource } from '../../types/common'; import { OCRLanguage } from '../../types/ocr'; export class VerticalOCRModule { - static module: VerticalOCRController; + private controller: VerticalOCRController; - static async load( + constructor() { + this.controller = new VerticalOCRController(); + } + + async load( model: { detectorLarge: ResourceSource; detectorNarrow: ResourceSource; @@ -16,11 +20,7 @@ export class VerticalOCRModule { independentCharacters: boolean, onDownloadProgressCallback: (progress: number) => void = () => {} ) { - this.module = new VerticalOCRController({ - modelDownloadProgressCallback: onDownloadProgressCallback, - }); - - 
await this.module.loadModel( + await this.controller.load( { detectorLarge: model.detectorLarge, detectorNarrow: model.detectorNarrow, @@ -30,11 +30,16 @@ export class VerticalOCRModule { recognizerSmall: model.recognizerSmall, }, model.language, - independentCharacters + independentCharacters, + onDownloadProgressCallback ); } - static async forward(input: string) { - return await this.module.forward(input); + async forward(input: string) { + return await this.controller.forward(input); + } + + delete() { + this.controller.delete(); } } diff --git a/packages/react-native-executorch/src/native/NativeOCR.ts b/packages/react-native-executorch/src/native/NativeOCR.ts deleted file mode 100644 index 2c14c6ac0d..0000000000 --- a/packages/react-native-executorch/src/native/NativeOCR.ts +++ /dev/null @@ -1,16 +0,0 @@ -import type { TurboModule } from 'react-native'; -import { TurboModuleRegistry } from 'react-native'; -import { OCRDetection } from '../types/ocr'; - -export interface Spec extends TurboModule { - loadModule( - detectorSource: string, - recognizerSourceLarge: string, - recognizerSourceMedium: string, - recognizerSourceSmall: string, - symbols: string - ): Promise; - forward(input: string): Promise; -} - -export default TurboModuleRegistry.get('OCR'); diff --git a/packages/react-native-executorch/src/native/NativeVerticalOCR.ts b/packages/react-native-executorch/src/native/NativeVerticalOCR.ts deleted file mode 100644 index 2aca8cbebc..0000000000 --- a/packages/react-native-executorch/src/native/NativeVerticalOCR.ts +++ /dev/null @@ -1,16 +0,0 @@ -import type { TurboModule } from 'react-native'; -import { TurboModuleRegistry } from 'react-native'; -import { OCRDetection } from '../types/ocr'; - -export interface Spec extends TurboModule { - loadModule( - detectorLargeSource: string, - detectorNarrowSource: string, - recognizerSource: string, - symbols: string, - independentCharacters: boolean - ): Promise; - forward(input: string): Promise; -} - -export 
default TurboModuleRegistry.get('VerticalOCR'); diff --git a/packages/react-native-executorch/src/native/RnExecutorchModules.ts b/packages/react-native-executorch/src/native/RnExecutorchModules.ts index 6207477942..3cf4a10bbb 100644 --- a/packages/react-native-executorch/src/native/RnExecutorchModules.ts +++ b/packages/react-native-executorch/src/native/RnExecutorchModules.ts @@ -1,6 +1,4 @@ import { Platform } from 'react-native'; -import { Spec as OCRInterface } from './NativeOCR'; -import { Spec as VerticalOCRInterface } from './NativeVerticalOCR'; import { Spec as ETInstallerInterface } from './NativeETInstaller'; const LINKING_ERROR = @@ -22,12 +20,7 @@ function returnSpecOrThrowLinkingError(spec: any) { ); } -const OCRNativeModule: OCRInterface = returnSpecOrThrowLinkingError( - require('./NativeOCR').default -); -const VerticalOCRNativeModule: VerticalOCRInterface = - returnSpecOrThrowLinkingError(require('./NativeVerticalOCR').default); const ETInstallerNativeModule: ETInstallerInterface = returnSpecOrThrowLinkingError(require('./NativeETInstaller').default); -export { OCRNativeModule, VerticalOCRNativeModule, ETInstallerNativeModule }; +export { ETInstallerNativeModule };