-
Notifications
You must be signed in to change notification settings - Fork 23
Expand file tree
/
Copy pathMain.cpp
More file actions
78 lines (69 loc) · 3.14 KB
/
Main.cpp
File metadata and controls
78 lines (69 loc) · 3.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
// Copyright 2021 The WebNN-native Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "examples/ResNet/ResNet.h"
int main(int argc, const char* argv[]) {
// Set input options for the example.
ResNet resnet;
if (!resnet.ParseAndCheckExampleOptions(argc, argv)) {
return -1;
}
// Pre-process the input image.
std::vector<float> processedPixels(resnet.mModelHeight * resnet.mModelWidth *
resnet.mModelChannels);
if (!utils::LoadAndPreprocessImage(&resnet, processedPixels)) {
return -1;
}
// Create a graph with weights and biases from .npy files.
const wnn::ContextOptions options =
utils::CreateContextOptions(resnet.mDevicePreference, resnet.mPowerPreference);
wnn::Context context = CreateCppContext(&options);
context.SetUncapturedErrorCallback(
[](WNNErrorType type, char const* message, void* userData) {
if (type != WNNErrorType_NoError) {
dawn::ErrorLog() << "Error type is " << type << ", message is " << message;
}
},
&resnet);
wnn::GraphBuilder builder = utils::CreateGraphBuilder(context);
wnn::Operand output =
resnet.mLayout == "nchw" ? resnet.LoadNchw(builder) : resnet.LoadNhwc(builder);
// Build the graph.
const std::chrono::time_point<std::chrono::high_resolution_clock> compilationStartTime =
std::chrono::high_resolution_clock::now();
wnn::Graph graph = utils::Build(builder, {{"output", output}});
if (!graph) {
dawn::ErrorLog() << "Failed to build graph.";
return -1;
}
const TIME_TYPE compilationElapsedTime =
std::chrono::high_resolution_clock::now() - compilationStartTime;
dawn::InfoLog() << "Compilation Time: " << compilationElapsedTime.count() << " ms";
// Compute the graph.
std::vector<float> result(utils::SizeOfShape(resnet.mOutputShape));
// Do the first inference for warming up if nIter > 1.
if (resnet.mNIter > 1) {
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
}
std::vector<TIME_TYPE> executionTime;
for (int i = 0; i < resnet.mNIter; ++i) {
std::chrono::time_point<std::chrono::high_resolution_clock> executionStartTime =
std::chrono::high_resolution_clock::now();
utils::Compute(graph, {{"input", processedPixels}}, {{"output", result}});
executionTime.push_back(std::chrono::high_resolution_clock::now() - executionStartTime);
}
// Print the result.
utils::PrintExexutionTime(executionTime);
utils::PrintResult(result, resnet.mLabelPath);
dawn::InfoLog() << "Done.";
}