Skip to content

Commit 27071c7

Browse files
authored
Merge pull request #273 from mingmingtasd/dml_optimize
Refactor the DML backend
2 parents 4458a39 + 270b5a4 commit 27071c7

5 files changed

Lines changed: 2632 additions & 2600 deletions

File tree

src/webnn/native/BUILD.gn

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ source_set("sources") {
218218
"dml/ContextDML.h",
219219
"dml/GraphDML.cpp",
220220
"dml/GraphDML.h",
221+
"dml/DMLUtils.cpp",
221222
"dml/DMLUtils.h",
222223
]
223224
}

src/webnn/native/dml/DMLUtils.cpp

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Copyright 2021 The WebNN-native Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "DMLUtils.h"
16+
17+
namespace webnn::native::dml {
18+
19+
bool IsWarpAdapter(IDXGIAdapter1* pAdapter) {
20+
DXGI_ADAPTER_DESC1 pDesc;
21+
WEBNN_CHECK(pAdapter->GetDesc1(&pDesc));
22+
// See here for documentation on filtering WARP adapter:
23+
// https://docs.microsoft.com/en-us/windows/desktop/direct3ddxgi/d3d10-graphics-programming-guide-dxgi#new-info-about-enumerating-adapters-for-windows-8
24+
auto isBasicRenderDriverVendorId = pDesc.VendorId == 0x1414;
25+
auto isBasicRenderDriverDeviceId = pDesc.DeviceId == 0x8c;
26+
auto isSoftwareAdapter = pDesc.Flags == DXGI_ADAPTER_FLAG_SOFTWARE;
27+
return isSoftwareAdapter || (isBasicRenderDriverVendorId && isBasicRenderDriverDeviceId);
28+
}
29+
30+
void InitD3D12(ComPtr<ID3D12GraphicsCommandList>& commandList,
31+
ComPtr<ID3D12CommandQueue>& commandQueue,
32+
ComPtr<ID3D12CommandAllocator>& commandAllocator,
33+
ComPtr<ID3D12Device>& D3D12Device,
34+
DXGI_GPU_PREFERENCE gpuPreference,
35+
bool useGpu) {
36+
#if defined(_DEBUG)
37+
ComPtr<ID3D12Debug> debug;
38+
if (SUCCEEDED(D3D12GetDebugInterface(IID_PPV_ARGS(&debug)))) {
39+
debug->EnableDebugLayer();
40+
}
41+
#endif
42+
ComPtr<IDXGIAdapter1> dxgiAdapter;
43+
if (useGpu) {
44+
ComPtr<IDXGIFactory6> dxgiFactory;
45+
WEBNN_CHECK(CreateDXGIFactory1(IID_PPV_ARGS(&dxgiFactory)));
46+
UINT i = 0;
47+
while (dxgiFactory->EnumAdapterByGpuPreference(
48+
i++, gpuPreference, IID_PPV_ARGS(&dxgiAdapter)) != DXGI_ERROR_NOT_FOUND) {
49+
if (!IsWarpAdapter(dxgiAdapter.Get())) {
50+
break;
51+
}
52+
}
53+
}
54+
if (!useGpu || FAILED(D3D12CreateDevice(dxgiAdapter.Get(), D3D_FEATURE_LEVEL_11_0,
55+
IID_PPV_ARGS(&D3D12Device)))) {
56+
// If a computer's display driver is not functioning or is disabled, the computer's
57+
// primary (NULL) adapter might also be called "Microsoft Basic Render Driver."
58+
ComPtr<IDXGIFactory4> dxgiFactory;
59+
WEBNN_CHECK(CreateDXGIFactory1(IID_PPV_ARGS(&dxgiFactory)));
60+
WEBNN_CHECK(dxgiFactory->EnumWarpAdapter(IID_PPV_ARGS(&dxgiAdapter)));
61+
WEBNN_CHECK(D3D12CreateDevice(dxgiAdapter.Get(), D3D_FEATURE_LEVEL_11_0,
62+
IID_PPV_ARGS(&D3D12Device)));
63+
}
64+
65+
D3D12_COMMAND_QUEUE_DESC commandQueueDesc{};
66+
commandQueueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
67+
commandQueueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
68+
WEBNN_CHECK(
69+
D3D12Device->CreateCommandQueue(&commandQueueDesc, IID_PPV_ARGS(&commandQueue)));
70+
WEBNN_CHECK(D3D12Device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT,
71+
IID_PPV_ARGS(&commandAllocator)));
72+
WEBNN_CHECK(D3D12Device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT,
73+
commandAllocator.Get(), nullptr,
74+
IID_PPV_ARGS(&commandList)));
75+
}
76+
77+
void CloseExecuteResetWait(ComPtr<ID3D12GraphicsCommandList> commandList,
78+
ComPtr<ID3D12CommandQueue> commandQueue,
79+
ComPtr<ID3D12CommandAllocator> commandAllocator,
80+
ComPtr<ID3D12Device> D3D12Device) {
81+
WEBNN_CHECK(commandList->Close());
82+
ID3D12CommandList* commandLists[] = {commandList.Get()};
83+
commandQueue->ExecuteCommandLists(ARRAYSIZE(commandLists), commandLists);
84+
WEBNN_CHECK(commandQueue.Get()->GetDevice(IID_PPV_ARGS(D3D12Device.GetAddressOf())));
85+
ComPtr<ID3D12Fence> fence;
86+
WEBNN_CHECK(
87+
D3D12Device->CreateFence(0, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(fence.GetAddressOf())));
88+
WEBNN_CHECK(commandQueue.Get()->Signal(fence.Get(), 1));
89+
WEBNN_CHECK(fence->SetEventOnCompletion(1, nullptr));
90+
WEBNN_CHECK(commandAllocator->Reset());
91+
WEBNN_CHECK(commandList->Reset(commandAllocator.Get(), nullptr));
92+
}
93+
} // namespace webnn::native::dml

0 commit comments

Comments
 (0)