-
Notifications
You must be signed in to change notification settings - Fork 23
Expand file tree
/
Copy pathContext.cpp
More file actions
128 lines (112 loc) · 5.32 KB
/
Context.cpp
File metadata and controls
128 lines (112 loc) · 5.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
// Copyright 2021 The WebNN-native Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "Context.h"
#include <napi.h>
#include <iostream>
#include "Graph.h"
#include "ML.h"
#include "Utils.h"
// Out-of-class definition of the static Context::constructor reference.
// Context::Initialize stores a persistent reference to the MLContext class constructor here.
Napi::FunctionReference node::Context::constructor;
namespace node {
Context::Context(const Napi::CallbackInfo& info) : Napi::ObjectWrap<Context>(info) {
wnn::ContextOptions options = {wnn::DevicePreference::Default,
wnn::PowerPreference::Default};
if (info.Length() > 0) {
Napi::Object optionsObject = info[0].As<Napi::Object>();
if (optionsObject.Has("powerPreference")) {
if (!optionsObject.Get("powerPreference").IsString()) {
Napi::Error::New(info.Env(), "Invaild powerPreference")
.ThrowAsJavaScriptException();
return;
}
std::string powerPreference = optionsObject.Get("powerPreference").ToString();
if (powerPreference == "default") {
options.powerPreference = wnn::PowerPreference::Default;
} else if (powerPreference == "low-power") {
options.powerPreference = wnn::PowerPreference::Low_power;
} else if (powerPreference == "high-performance") {
options.powerPreference = wnn::PowerPreference::High_performance;
} else {
Napi::Error::New(info.Env(), "Invaild powerPreference")
.ThrowAsJavaScriptException();
return;
}
}
if (optionsObject.Has("devicePreference")) {
if (!optionsObject.Get("devicePreference").IsString()) {
Napi::Error::New(info.Env(), "Invaild devicePreference")
.ThrowAsJavaScriptException();
return;
}
std::string devicePreference = optionsObject.Get("devicePreference").ToString();
if (devicePreference == "default") {
options.devicePreference = wnn::DevicePreference::Default;
} else if (devicePreference == "gpu") {
options.devicePreference = wnn::DevicePreference::Gpu;
} else if (devicePreference == "cpu") {
options.devicePreference = wnn::DevicePreference::Cpu;
} else {
Napi::Error::New(info.Env(), "Invaild devicePreference")
.ThrowAsJavaScriptException();
return;
}
}
}
mImpl = wnn::Context::Acquire(ML::GetInstance()->CreateContext(&options));
if (!mImpl) {
Napi::Error::New(info.Env(), "Failed to create Context").ThrowAsJavaScriptException();
return;
}
mImpl.SetUncapturedErrorCallback(
[](WNNErrorType type, char const* message, void* userData) {
if (type != WNNErrorType_NoError) {
std::cout << "Uncaptured Error type is " << type << ", message is " << message
<< std::endl;
}
},
this);
}
// Returns, by value, the native wnn::Context held in mImpl.
// The handle is null if construction failed.
wnn::Context Context::GetImpl() {
    return this->mImpl;
}
// Defines the JavaScript class "MLContext" and exposes it on `exports`.
// The class has a single enumerable instance method, compute().
// A persistent reference to the class constructor is kept in Context::constructor.
// SuppressDestruct() stops that static reference from being released during static
// destruction.
Napi::Object Context::Initialize(Napi::Env env, Napi::Object exports) {
    Napi::HandleScope scope(env);
    Napi::Function mlContextClass =
        DefineClass(env, "MLContext",
                    {InstanceMethod("compute", &Context::Compute, napi_enumerable)});
    constructor = Napi::Persistent(mlContextClass);
    constructor.SuppressDestruct();
    exports.Set("MLContext", mlContextClass);
    return exports;
}
// Implements MLContext.compute(graph, inputs, outputs).
// Validates the three arguments, converts the inputs/outputs maps into wnn::NamedInputs /
// wnn::NamedOutputs, and forwards them with the unwrapped graph to
// wnn::Context::ComputeSync.
// Throws a JavaScript error (via WEBNN_NODE_ASSERT) on a wrong argument count or an
// invalid argument. Otherwise returns the number 0.
Napi::Value Context::Compute(const Napi::CallbackInfo& info) {
    // JS signature: compute(MLGraph graph, MLNamedInputs inputs, MLNamedOutputs outputs)
    WEBNN_NODE_ASSERT(info.Length() == 3, "The number of arguments is invalid.");

    // Only an object can carry a native wrapper; reject primitives before unwrapping.
    WEBNN_NODE_ASSERT(info[0].IsObject(), "The graph parameter is invalid.");
    Napi::Object object = info[0].As<Napi::Object>();
    node::Graph* jsGraph = Napi::ObjectWrap<node::Graph>::Unwrap(object);
    if (jsGraph == nullptr) {
        // Unwrap yields nullptr for objects without a native wrapper. node-addon-api may
        // already have raised a JavaScript error for that; don't throw a second one on top.
        if (info.Env().IsExceptionPending()) {
            return info.Env().Undefined();
        }
        WEBNN_NODE_ASSERT(false, "The graph parameter is invalid.");
    }
    // NOTE(review): Unwrap cannot distinguish a Graph wrapper from another ObjectWrap type.
    // An InstanceOf check against Graph's JS constructor would close that gap — TODO confirm
    // Graph exposes one.

    std::map<std::string, Input> inputs;
    WEBNN_NODE_ASSERT(GetNamedInputs(info[1], inputs), "The inputs parameter is invalid.");
    std::map<std::string, wnn::Resource> outputs;
    WEBNN_NODE_ASSERT(GetNamedOutputs(info[2], outputs), "The outputs parameter is invalid.");

    wnn::NamedInputs namedInputs = wnn::CreateNamedInputs();
    for (auto& input : inputs) {
        namedInputs.Set(input.first.c_str(), input.second.AsPtr());
    }
    wnn::NamedOutputs namedOutputs = wnn::CreateNamedOutputs();
    for (auto& output : outputs) {
        // The pointer refers into `outputs`, which is not modified and stays alive until
        // after ComputeSync returns.
        namedOutputs.Set(output.first.c_str(), &output.second);
    }
    mImpl.ComputeSync(jsGraph->GetImpl(), namedInputs, namedOutputs);
    return Napi::Number::New(info.Env(), 0);
}
} // namespace node