Skip to content

Commit 9697ccc

Browse files
committed
Create and initialize Device object: the first PR for rewriting DML backend
1 parent 294de6b commit 9697ccc

10 files changed

Lines changed: 478 additions & 0 deletions

File tree

src/webnn/native/BUILD.gn

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,19 @@ source_set("sources") {
210210
}
211211
}
212212

213+
if (webnn_enable_dml) {
214+
sources += [
215+
"dml/BackendDML.cpp",
216+
"dml/BackendDML.h",
217+
"dml/ContextDML.cpp",
218+
"dml/ContextDML.h",
219+
"dml/ExecutionContextDML.cpp",
220+
"dml/ExecutionContextDML.h",
221+
"dml/GraphDML.cpp",
222+
"dml/GraphDML.h",
223+
]
224+
}
225+
213226
if (webnn_enable_dmlx) {
214227
if (webnn_enable_gpu_buffer == false) {
215228
sources += [
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Copyright 2019 The Dawn Authors
2+
// Copyright 2022 The WebNN-native Authors
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#include "webnn/native/dml/BackendDML.h"
17+
18+
#include "webnn/native/Instance.h"
19+
#include "webnn/native/dml/ContextDML.h"
20+
21+
namespace webnn::native::dml {
22+
23+
Backend::Backend(InstanceBase* instance)
24+
: BackendConnection(instance, wnn::BackendType::DirectML) {
25+
}
26+
27+
MaybeError Backend::Initialize() {
28+
return {};
29+
}
30+
31+
ContextBase* Backend::CreateContext(ContextOptions const* options) {
32+
return new Context(options);
33+
}
34+
35+
BackendConnection* Connect(InstanceBase* instance) {
36+
Backend* backend = new Backend(instance);
37+
38+
if (instance->ConsumedError(backend->Initialize())) {
39+
delete backend;
40+
return nullptr;
41+
}
42+
43+
return backend;
44+
}
45+
46+
} // namespace webnn::native::dml

src/webnn/native/dml/BackendDML.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright 2019 The Dawn Authors
2+
// Copyright 2022 The WebNN-native Authors
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#ifndef WEBNN_NATIVE_DML_BACKEND_DML_H_
17+
#define WEBNN_NATIVE_DML_BACKEND_DML_H_
18+
19+
#include <memory>
20+
#include "webnn/native/BackendConnection.h"
21+
#include "webnn/native/Context.h"
22+
#include "webnn/native/Error.h"
23+
24+
namespace webnn::native::dml {
25+
26+
class Backend : public BackendConnection {
27+
public:
28+
Backend(InstanceBase* instance);
29+
30+
MaybeError Initialize();
31+
ContextBase* CreateContext(ContextOptions const* options = nullptr) override;
32+
33+
private:
34+
};
35+
36+
} // namespace webnn::native::dml
37+
38+
#endif // WEBNN_NATIVE_DML_BACKEND_DML_H_
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright 2022 The WebNN-native Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "webnn/native/dml/ContextDML.h"
16+
17+
#include "common/RefCounted.h"
18+
#include "webnn/native/dml/GraphDML.h"
19+
20+
namespace webnn::native::dml {
21+
22+
Context::Context(ContextOptions const* options) : ContextBase(options) {
23+
}
24+
25+
GraphBase* Context::CreateGraphImpl() {
26+
return new Graph(this);
27+
}
28+
29+
} // namespace webnn::native::dml

src/webnn/native/dml/ContextDML.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2022 The WebNN-native Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef WEBNN_NATIVE_DML_CONTEXT_DML_H_
16+
#define WEBNN_NATIVE_DML_CONTEXT_DML_H_
17+
18+
#include "webnn/native/Context.h"
19+
#include "webnn/native/Graph.h"
20+
21+
namespace webnn::native::dml {
22+
23+
class Context : public ContextBase {
24+
public:
25+
explicit Context(ContextOptions const* options);
26+
~Context() override = default;
27+
28+
private:
29+
GraphBase* CreateGraphImpl() override;
30+
};
31+
32+
} // namespace webnn::native::dml
33+
34+
#endif // WEBNN_NATIVE_DML_CONTEXT_DML_H_
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// Copyright 2022 The WebNN-native Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "ExecutionContextDML.h"
16+
17+
namespace webnn::native::dml {
18+
19+
// An adapter called the "Microsoft Basic Render Driver" is always present. This adapter is a
20+
// render-only device that has no display outputs.
21+
inline bool IsSoftwareAdapter(IDXGIAdapter1* pAdapter) {
22+
DXGI_ADAPTER_DESC1 pDesc;
23+
pAdapter->GetDesc1(&pDesc);
24+
// See here for documentation on filtering WARP adapter:
25+
// https://docs.microsoft.com/en-us/windows/desktop/direct3ddxgi/d3d10-graphics-programming-guide-dxgi#new-info-about-enumerating-adapters-for-windows-8
26+
return pDesc.Flags == DXGI_ADAPTER_FLAG_SOFTWARE ||
27+
(pDesc.VendorId == 0x1414 && pDesc.DeviceId == 0x8c);
28+
}
29+
30+
HRESULT EnumAdapter(DXGI_GPU_PREFERENCE gpuPreference,
31+
bool useGpu,
32+
ComPtr<IDXGIAdapter1> adapter) {
33+
ComPtr<IDXGIFactory6> dxgiFactory;
34+
RETURN_IF_FAILED(CreateDXGIFactory1(IID_PPV_ARGS(&dxgiFactory)));
35+
if (useGpu) {
36+
UINT adapterIndex = 0;
37+
while (dxgiFactory->EnumAdapterByGpuPreference(adapterIndex++, gpuPreference,
38+
IID_PPV_ARGS(&adapter)) !=
39+
DXGI_ERROR_NOT_FOUND) {
40+
if (!IsSoftwareAdapter(adapter.Get())) {
41+
break;
42+
}
43+
}
44+
} else {
45+
RETURN_IF_FAILED(dxgiFactory->EnumWarpAdapter(IID_PPV_ARGS(&adapter)));
46+
}
47+
return S_OK;
48+
}
49+
50+
ExecutionContext::ExecutionContext(ComPtr<IDXGIAdapter1> adapter, bool useDebugLayer)
51+
: mAdapter(std::move(adapter)), mUseDebugLayer(useDebugLayer) {
52+
}
53+
54+
// static
55+
std::unique_ptr<ExecutionContext> ExecutionContext::Create(ComPtr<IDXGIAdapter1> adapter,
56+
bool useDebugLayer) {
57+
std::unique_ptr<ExecutionContext> executionContext(
58+
new ExecutionContext(adapter, useDebugLayer));
59+
if (FAILED(executionContext->Initialize())) {
60+
dawn::ErrorLog() << "Failed to initialize Device.";
61+
return nullptr;
62+
}
63+
return executionContext;
64+
}
65+
66+
HRESULT ExecutionContext::Initialize() {
67+
if (mUseDebugLayer) {
68+
ComPtr<ID3D12Debug> debug;
69+
if (SUCCEEDED(D3D12GetDebugInterface(IID_PPV_ARGS(&debug)))) {
70+
debug->EnableDebugLayer();
71+
}
72+
}
73+
RETURN_IF_FAILED(
74+
D3D12CreateDevice(mAdapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&mD3D12Device)));
75+
D3D12_COMMAND_QUEUE_DESC commandQueueDesc{};
76+
commandQueueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
77+
commandQueueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
78+
RETURN_IF_FAILED(
79+
mD3D12Device->CreateCommandQueue(&commandQueueDesc, IID_PPV_ARGS(&mCommandQueue)));
80+
RETURN_IF_FAILED(mD3D12Device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT,
81+
IID_PPV_ARGS(&mCommandAllocator)));
82+
RETURN_IF_FAILED(mD3D12Device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT,
83+
mCommandAllocator.Get(), nullptr,
84+
IID_PPV_ARGS(&mCommandList)));
85+
86+
// Create the DirectML device.
87+
DML_CREATE_DEVICE_FLAGS dmlCreateDeviceFlags = DML_CREATE_DEVICE_FLAG_NONE;
88+
#if defined(_DEBUG)
89+
dmlCreateDeviceFlags = DML_CREATE_DEVICE_FLAG_DEBUG;
90+
#endif
91+
if (dmlCreateDeviceFlags == DML_CREATE_DEVICE_FLAG_DEBUG) {
92+
if (FAILED(DMLCreateDevice(mD3D12Device.Get(), dmlCreateDeviceFlags,
93+
IID_PPV_ARGS(&mDevice)))) {
94+
dawn::WarningLog() << "Failed to create a DirectML device with debug flag, "
95+
"will fall back to use none flag.";
96+
RETURN_IF_FAILED(DMLCreateDevice(mD3D12Device.Get(), DML_CREATE_DEVICE_FLAG_NONE,
97+
IID_PPV_ARGS(&mDevice)));
98+
}
99+
} else {
100+
RETURN_IF_FAILED(
101+
DMLCreateDevice(mD3D12Device.Get(), dmlCreateDeviceFlags, IID_PPV_ARGS(&mDevice)));
102+
}
103+
return S_OK;
104+
};
105+
106+
} // namespace webnn::native::dml
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// Copyright 2022 The WebNN-native Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef WEBNN_NATIVE_DML_EXECUTIONCONTEXTEDML_H_
16+
#define WEBNN_NATIVE_DML_EXECUTIONCONTEXTEDML_H_
17+
18+
#include <webnn/webnn_cpp.h>
19+
#include <unordered_map>
20+
21+
#include "common/Log.h"
22+
#include "dml_platform.h"
23+
#include "webnn/native/NamedOutputs.h"
24+
#include "webnn/native/webnn_platform.h"
25+
26+
#define RETURN_IF_FAILED(EXPR) \
27+
do { \
28+
auto HR = EXPR; \
29+
if (FAILED(HR)) { \
30+
dawn::ErrorLog() << "Failed to do " << #EXPR << " Return HRESULT " << std::hex << HR; \
31+
return HR; \
32+
} \
33+
} while (0)
34+
35+
namespace webnn::native::dml {
36+
37+
HRESULT EnumAdapter(DXGI_GPU_PREFERENCE gpuPreference,
38+
bool useGpu,
39+
ComPtr<IDXGIAdapter1> adapter);
40+
41+
class ExecutionContext {
42+
public:
43+
static std::unique_ptr<ExecutionContext> Create(ComPtr<IDXGIAdapter1> adapter,
44+
bool useDebugLayer);
45+
46+
private:
47+
ExecutionContext(ComPtr<IDXGIAdapter1> adapter, bool useDebugLayer);
48+
HRESULT Initialize();
49+
50+
ComPtr<IDMLDevice> mDevice;
51+
ComPtr<ID3D12Device> mD3D12Device;
52+
ComPtr<IDMLCommandRecorder> mCommandRecorder;
53+
ComPtr<ID3D12CommandQueue> mCommandQueue;
54+
ComPtr<ID3D12CommandAllocator> mCommandAllocator;
55+
ComPtr<ID3D12GraphicsCommandList> mCommandList;
56+
57+
ComPtr<IDXGIAdapter1> mAdapter;
58+
bool mUseDebugLayer = false;
59+
};
60+
61+
} // namespace webnn::native::dml
62+
63+
#endif // WEBNN_NATIVE_DML_EXECUTIONCONTEXTEDML_H_

0 commit comments

Comments
 (0)