1+ // Copyright 2021 The WebNN-native Authors
2+ //
3+ // Licensed under the Apache License, Version 2.0 (the "License");
4+ // you may not use this file except in compliance with the License.
5+ // You may obtain a copy of the License at
6+ //
7+ // http://www.apache.org/licenses/LICENSE-2.0
8+ //
9+ // Unless required by applicable law or agreed to in writing, software
10+ // distributed under the License is distributed on an "AS IS" BASIS,
11+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ // See the License for the specific language governing permissions and
13+ // limitations under the License.
14+
15+ #include " DMLUtils.h"
16+
17+ namespace webnn ::native::dml {
18+
19+ bool IsWarpAdapter (IDXGIAdapter1* pAdapter) {
20+ DXGI_ADAPTER_DESC1 pDesc;
21+ WEBNN_CHECK (pAdapter->GetDesc1 (&pDesc));
22+ // See here for documentation on filtering WARP adapter:
23+ // https://docs.microsoft.com/en-us/windows/desktop/direct3ddxgi/d3d10-graphics-programming-guide-dxgi#new-info-about-enumerating-adapters-for-windows-8
24+ auto isBasicRenderDriverVendorId = pDesc.VendorId == 0x1414 ;
25+ auto isBasicRenderDriverDeviceId = pDesc.DeviceId == 0x8c ;
26+ auto isSoftwareAdapter = pDesc.Flags == DXGI_ADAPTER_FLAG_SOFTWARE;
27+ return isSoftwareAdapter || (isBasicRenderDriverVendorId && isBasicRenderDriverDeviceId);
28+ }
29+
30+ void InitD3D12 (ComPtr<ID3D12GraphicsCommandList>& commandList,
31+ ComPtr<ID3D12CommandQueue>& commandQueue,
32+ ComPtr<ID3D12CommandAllocator>& commandAllocator,
33+ ComPtr<ID3D12Device>& D3D12Device,
34+ DXGI_GPU_PREFERENCE gpuPreference,
35+ bool useGpu) {
36+ #if defined(_DEBUG)
37+ ComPtr<ID3D12Debug> debug;
38+ if (SUCCEEDED (D3D12GetDebugInterface (IID_PPV_ARGS (&debug)))) {
39+ debug->EnableDebugLayer ();
40+ }
41+ #endif
42+ ComPtr<IDXGIAdapter1> dxgiAdapter;
43+ if (useGpu) {
44+ ComPtr<IDXGIFactory6> dxgiFactory;
45+ WEBNN_CHECK (CreateDXGIFactory1 (IID_PPV_ARGS (&dxgiFactory)));
46+ UINT i = 0 ;
47+ while (dxgiFactory->EnumAdapterByGpuPreference (
48+ i++, gpuPreference, IID_PPV_ARGS (&dxgiAdapter)) != DXGI_ERROR_NOT_FOUND) {
49+ if (!IsWarpAdapter (dxgiAdapter.Get ())) {
50+ break ;
51+ }
52+ }
53+ }
54+ if (!useGpu || FAILED (D3D12CreateDevice (dxgiAdapter.Get (), D3D_FEATURE_LEVEL_11_0,
55+ IID_PPV_ARGS (&D3D12Device)))) {
56+ // If a computer's display driver is not functioning or is disabled, the computer's
57+ // primary (NULL) adapter might also be called "Microsoft Basic Render Driver."
58+ ComPtr<IDXGIFactory4> dxgiFactory;
59+ WEBNN_CHECK (CreateDXGIFactory1 (IID_PPV_ARGS (&dxgiFactory)));
60+ WEBNN_CHECK (dxgiFactory->EnumWarpAdapter (IID_PPV_ARGS (&dxgiAdapter)));
61+ WEBNN_CHECK (D3D12CreateDevice (dxgiAdapter.Get (), D3D_FEATURE_LEVEL_11_0,
62+ IID_PPV_ARGS (&D3D12Device)));
63+ }
64+
65+ D3D12_COMMAND_QUEUE_DESC commandQueueDesc{};
66+ commandQueueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
67+ commandQueueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
68+ WEBNN_CHECK (
69+ D3D12Device->CreateCommandQueue (&commandQueueDesc, IID_PPV_ARGS (&commandQueue)));
70+ WEBNN_CHECK (D3D12Device->CreateCommandAllocator (D3D12_COMMAND_LIST_TYPE_DIRECT,
71+ IID_PPV_ARGS (&commandAllocator)));
72+ WEBNN_CHECK (D3D12Device->CreateCommandList (0 , D3D12_COMMAND_LIST_TYPE_DIRECT,
73+ commandAllocator.Get (), nullptr ,
74+ IID_PPV_ARGS (&commandList)));
75+ }
76+
77+ void CloseExecuteResetWait (ComPtr<ID3D12GraphicsCommandList> commandList,
78+ ComPtr<ID3D12CommandQueue> commandQueue,
79+ ComPtr<ID3D12CommandAllocator> commandAllocator,
80+ ComPtr<ID3D12Device> D3D12Device) {
81+ WEBNN_CHECK (commandList->Close ());
82+ ID3D12CommandList* commandLists[] = {commandList.Get ()};
83+ commandQueue->ExecuteCommandLists (ARRAYSIZE (commandLists), commandLists);
84+ WEBNN_CHECK (commandQueue.Get ()->GetDevice (IID_PPV_ARGS (D3D12Device.GetAddressOf ())));
85+ ComPtr<ID3D12Fence> fence;
86+ WEBNN_CHECK (
87+ D3D12Device->CreateFence (0 , D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS (fence.GetAddressOf ())));
88+ WEBNN_CHECK (commandQueue.Get ()->Signal (fence.Get (), 1 ));
89+ WEBNN_CHECK (fence->SetEventOnCompletion (1 , nullptr ));
90+ WEBNN_CHECK (commandAllocator->Reset ());
91+ WEBNN_CHECK (commandList->Reset (commandAllocator.Get (), nullptr ));
92+ }
93+ } // namespace webnn::native::dml
0 commit comments