-
Notifications
You must be signed in to change notification settings - Fork 31
Expand file tree
/
Copy pathRNMLKitObjectDetectionModule.swift
More file actions
111 lines (93 loc) · 4.79 KB
/
RNMLKitObjectDetectionModule.swift
File metadata and controls
111 lines (93 loc) · 4.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import ExpoModulesCore
import RNMLKitCore
/// Record describing an object detection model: an identifier, a file path,
/// and the detector options to apply to it.
public struct RNMLKitObjectDetectionModelSpec: Record {
  /// Identifier for the model; defaults to an empty string.
  @Field var modelName: String = ""

  /// Location of the model file; defaults to an empty string.
  @Field var modelPath: String = ""

  /// Detector options for this model.
  @Field var options: RNMLKitObjectDetectorOptionsRecord

  public init() {}
}
/// Record passed to `loadCustomModel`: the name to register the detector under,
/// the path of the custom model file (optionally prefixed with `file://`),
/// and the custom detector options.
public struct RNMLKitObjectDetectionCustomModelSpec: Record {
  /// Key under which the loaded detector is stored; defaults to an empty string.
  @Field var modelName: String = ""

  /// Path (or `file://` URI) of the custom model; defaults to an empty string.
  @Field var modelPath: String = ""

  /// Options used to build the custom detector.
  @Field var options: RNMLKitCustomObjectDetectorOptionsRecord

  public init() {}
}
/// Expo module exposing object detection to JavaScript under the name `RNMLKitObjectDetection`.
///
/// Loaded detectors are cached in `objectDetectors`, keyed by model name. The detector created
/// by `loadDefaultModel` is stored under the key `"default"`.
/// See https://docs.expo.dev/modules/module-api for the module definition components used here.
public class RNMLKitObjectDetectionModule: Module {
  /// Error domain passed to `rejectPromiseWithMessage` for every rejection raised by this module.
  let ERROR_DOMAIN: String = "red.infinite.reactnativemlkit.ObjectDetectionErrorDomain"

  /// Loaded detectors keyed by model name. Loading under an existing name replaces the old detector.
  var objectDetectors: [String: RNMLKitObjectDetectorCommon] = [:]

  /// Key under which `loadDefaultModel` stores its detector.
  private static let defaultModelName = "default"

  /// URI scheme prefix removed from custom model paths before they reach the detector.
  private static let fileURLPrefix = "file://"

  /// Returns `path` with a leading `file://` removed; any other path is returned unchanged.
  ///
  /// A plain prefix check is used instead of a regex over a fixed `NSMakeRange(0, 7)`:
  /// that range is out of bounds for paths shorter than seven UTF-16 units, which makes
  /// Foundation raise an exception and crash the app.
  private static func stripFileURLPrefix(_ path: String) -> String {
    guard path.hasPrefix(fileURLPrefix) else { return path }
    return String(path.dropFirst(fileURLPrefix.count))
  }

  public func definition() -> ModuleDefinition {
    Name("RNMLKitObjectDetection")

    // Builds a custom-model detector from `spec` and registers it under `spec.modelName`.
    // Rejects if the options record cannot be converted or the detector cannot be created.
    AsyncFunction("loadCustomModel") { (spec: RNMLKitObjectDetectionCustomModelSpec, promise: Promise) in
      let trimmedPath = RNMLKitObjectDetectionModule.stripFileURLPrefix(spec.modelPath)
      print("ExpoMLKItObjectDetection: Loading Custom Model name:\(spec.modelName) modelPath:\(trimmedPath)")

      let customModelOptions: RNMLKitCustomObjectDetectorOptions
      do {
        customModelOptions = try RNMLKitCustomObjectDetectorOptions(record: spec.options)
      } catch {
        rejectPromiseWithMessage(promise: promise, message: "Error creating options object \(error.localizedDescription)", domain: self.ERROR_DOMAIN)
        return
      }

      do {
        let objectDetector = try RNMLKitCustomObjectDetector(name: spec.modelName, modelPath: trimmedPath, options: customModelOptions)
        self.objectDetectors[spec.modelName] = objectDetector
      } catch {
        rejectPromiseWithMessage(promise: promise, message: "Error instantiating object detector: \(error.localizedDescription)", domain: self.ERROR_DOMAIN)
        return
      }
      promise.resolve()
    }

    // Creates the default detector and registers it under "default".
    // If a default detector is already registered this resolves immediately, so `options`
    // passed to later calls are ignored.
    AsyncFunction("loadDefaultModel") { (options: RNMLKitObjectDetectorOptionsRecord?, promise: Promise) in
      let defaultName = RNMLKitObjectDetectionModule.defaultModelName
      if self.objectDetectors[defaultName] != nil {
        promise.resolve()
        return
      }

      // nil unless the caller supplied an options record.
      var defaultModelOptions: RNMLKitObjectDetectorOptions?
      if let optionsRecord = options {
        do {
          defaultModelOptions = try RNMLKitObjectDetectorOptions(record: optionsRecord)
        } catch {
          rejectPromiseWithMessage(promise: promise, message: "Error creating options object \(error.localizedDescription)", domain: self.ERROR_DOMAIN)
          return
        }
      }

      self.objectDetectors[defaultName] = RNMLKitObjectDetector(options: defaultModelOptions)
      promise.resolve()
    }

    // Runs detection on `imagePath` with the detector registered as `modelName`.
    // Resolves with the detector's result; rejects with "Model Not Found" if no such detector
    // is loaded, or with the detection error otherwise.
    AsyncFunction("detectObjects") { (modelName: String, imagePath: String, promise: Promise) in
      let logger = Logger(logHandlers: [createOSLogHandler(category: Logger.EXPO_LOG_CATEGORY)])
      guard let objectDetector = self.objectDetectors[modelName] else {
        logger.error("Model Not Found")
        rejectPromiseWithMessage(promise: promise, message: "Model Not Found", domain: self.ERROR_DOMAIN)
        return
      }

      let errorDomain = self.ERROR_DOMAIN
      // `detectObjects(imagePath:)` is async, so settle the promise from a Task.
      Task {
        do {
          let result = try await objectDetector.detectObjects(imagePath: imagePath)
          logger.info("detectObjects(\(modelName)): found \(result.count) objects")
          logger.info("RNMLKitObjectDetection", "detectObjects: Detection completed successfully")
          promise.resolve(result)
        } catch {
          rejectPromiseWithMessage(promise: promise, message: "Error Detecting Objects \(error)", domain: errorDomain)
        }
      }
    }

    // Synchronously reports whether a detector is registered under `modelName`.
    Function("isLoaded") { (modelName: String) -> Bool in
      self.objectDetectors[modelName] != nil
    }
  }
}