|
| 1 | +// Copyright 2022 The WebNN-native Authors |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include "ExecutionContextDML.h" |
| 16 | + |
| 17 | +namespace webnn::native::dml { |
| 18 | + |
| 19 | + // An adapter called the "Microsoft Basic Render Driver" is always present. This adapter is a |
| 20 | + // render-only device that has no display outputs. |
| 21 | + inline bool IsSoftwareAdapter(IDXGIAdapter1* pAdapter) { |
| 22 | + DXGI_ADAPTER_DESC1 pDesc; |
| 23 | + pAdapter->GetDesc1(&pDesc); |
| 24 | + // See here for documentation on filtering WARP adapter: |
| 25 | + // https://docs.microsoft.com/en-us/windows/desktop/direct3ddxgi/d3d10-graphics-programming-guide-dxgi#new-info-about-enumerating-adapters-for-windows-8 |
| 26 | + return pDesc.Flags == DXGI_ADAPTER_FLAG_SOFTWARE || |
| 27 | + (pDesc.VendorId == 0x1414 && pDesc.DeviceId == 0x8c); |
| 28 | + } |
| 29 | + |
| 30 | + HRESULT EnumAdapter(DXGI_GPU_PREFERENCE gpuPreference, |
| 31 | + bool useGpu, |
| 32 | + ComPtr<IDXGIAdapter1> adapter) { |
| 33 | + ComPtr<IDXGIFactory6> dxgiFactory; |
| 34 | + RETURN_IF_FAILED(CreateDXGIFactory1(IID_PPV_ARGS(&dxgiFactory))); |
| 35 | + if (useGpu) { |
| 36 | + UINT adapterIndex = 0; |
| 37 | + while (dxgiFactory->EnumAdapterByGpuPreference(adapterIndex++, gpuPreference, |
| 38 | + IID_PPV_ARGS(&adapter)) != |
| 39 | + DXGI_ERROR_NOT_FOUND) { |
| 40 | + if (!IsSoftwareAdapter(adapter.Get())) { |
| 41 | + break; |
| 42 | + } |
| 43 | + } |
| 44 | + } else { |
| 45 | + RETURN_IF_FAILED(dxgiFactory->EnumWarpAdapter(IID_PPV_ARGS(&adapter))); |
| 46 | + } |
| 47 | + return S_OK; |
| 48 | + } |
| 49 | + |
| 50 | + ExecutionContext::ExecutionContext(ComPtr<IDXGIAdapter1> adapter, bool useDebugLayer) |
| 51 | + : mAdapter(std::move(adapter)), mUseDebugLayer(useDebugLayer) { |
| 52 | + } |
| 53 | + |
| 54 | + // static |
| 55 | + std::unique_ptr<ExecutionContext> ExecutionContext::Create(ComPtr<IDXGIAdapter1> adapter, |
| 56 | + bool useDebugLayer) { |
| 57 | + std::unique_ptr<ExecutionContext> executionContext( |
| 58 | + new ExecutionContext(adapter, useDebugLayer)); |
| 59 | + if (FAILED(executionContext->Initialize())) { |
| 60 | + dawn::ErrorLog() << "Failed to initialize Device."; |
| 61 | + return nullptr; |
| 62 | + } |
| 63 | + return executionContext; |
| 64 | + } |
| 65 | + |
| 66 | + HRESULT ExecutionContext::Initialize() { |
| 67 | + if (mUseDebugLayer) { |
| 68 | + ComPtr<ID3D12Debug> debug; |
| 69 | + if (SUCCEEDED(D3D12GetDebugInterface(IID_PPV_ARGS(&debug)))) { |
| 70 | + debug->EnableDebugLayer(); |
| 71 | + } |
| 72 | + } |
| 73 | + RETURN_IF_FAILED( |
| 74 | + D3D12CreateDevice(mAdapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&mD3D12Device))); |
| 75 | + D3D12_COMMAND_QUEUE_DESC commandQueueDesc{}; |
| 76 | + commandQueueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; |
| 77 | + commandQueueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE; |
| 78 | + RETURN_IF_FAILED( |
| 79 | + mD3D12Device->CreateCommandQueue(&commandQueueDesc, IID_PPV_ARGS(&mCommandQueue))); |
| 80 | + RETURN_IF_FAILED(mD3D12Device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, |
| 81 | + IID_PPV_ARGS(&mCommandAllocator))); |
| 82 | + RETURN_IF_FAILED(mD3D12Device->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_DIRECT, |
| 83 | + mCommandAllocator.Get(), nullptr, |
| 84 | + IID_PPV_ARGS(&mCommandList))); |
| 85 | + |
| 86 | + // Create the DirectML device. |
| 87 | + DML_CREATE_DEVICE_FLAGS dmlCreateDeviceFlags = DML_CREATE_DEVICE_FLAG_NONE; |
| 88 | +#if defined(_DEBUG) |
| 89 | + dmlCreateDeviceFlags = DML_CREATE_DEVICE_FLAG_DEBUG; |
| 90 | +#endif |
| 91 | + if (dmlCreateDeviceFlags == DML_CREATE_DEVICE_FLAG_DEBUG) { |
| 92 | + if (FAILED(DMLCreateDevice(mD3D12Device.Get(), dmlCreateDeviceFlags, |
| 93 | + IID_PPV_ARGS(&mDevice)))) { |
| 94 | + dawn::WarningLog() << "Failed to create a DirectML device with debug flag, " |
| 95 | + "will fall back to use none flag."; |
| 96 | + RETURN_IF_FAILED(DMLCreateDevice(mD3D12Device.Get(), DML_CREATE_DEVICE_FLAG_NONE, |
| 97 | + IID_PPV_ARGS(&mDevice))); |
| 98 | + } |
| 99 | + } else { |
| 100 | + RETURN_IF_FAILED( |
| 101 | + DMLCreateDevice(mD3D12Device.Get(), dmlCreateDeviceFlags, IID_PPV_ARGS(&mDevice))); |
| 102 | + } |
| 103 | + return S_OK; |
| 104 | + }; |
| 105 | + |
| 106 | +} // namespace webnn::native::dml |
0 commit comments