@@ -13,10 +13,14 @@ using namespace pydml;
1313using Microsoft::WRL ::ComPtr;
1414
1515SVDescriptorHeap::SVDescriptorHeap (
16- ComPtr<ID3D12DescriptorHeap> heap,
17- uint64_t size)
18- : gpgmm::d3d12::Heap(heap, DXGI_MEMORY_SEGMENT_GROUP_LOCAL , size),
19- m_Heap(std::move(heap)) {
16+ ComPtr<gpgmm::d3d12::Heap> heap)
17+ : m_Heap(std::move(heap)) {
18+ }
19+
20+ ID3D12DescriptorHeap* SVDescriptorHeap::GetDescriptorHeap () const {
21+ ComPtr<ID3D12DescriptorHeap> descriptorHeap;
22+ m_Heap->As (&descriptorHeap);
23+ return descriptorHeap.Get ();
2024}
2125
2226Device::Device (bool useGpu, bool useDebugLayer, DXGI_GPU_PREFERENCE gpuPreference) : m_useGpu(useGpu), m_useDebugLayer(useDebugLayer), m_gpuPreference(gpuPreference) {}
@@ -101,20 +105,16 @@ HRESULT Device::Init()
101105 nullptr , // initial pipeline state
102106 IID_GRAPHICS_PPV_ARGS (m_commandList.GetAddressOf ())));
103107
104- D3D12_FEATURE_DATA_ARCHITECTURE arch = {};
105- ReturnIfFailed (m_d3d12Device->CheckFeatureSupport (D3D12_FEATURE_ARCHITECTURE , &arch, sizeof (arch)));
106-
107108 D3D12_FEATURE_DATA_D3D12_OPTIONS options = {};
108109 ReturnIfFailed (m_d3d12Device->CheckFeatureSupport (D3D12_FEATURE_D3D12_OPTIONS , &options, sizeof (options)));
109110
110111 gpgmm::d3d12::ALLOCATOR_DESC allocatorDesc = {};
111112 allocatorDesc.Adapter = dxgiAdapter;
112113 allocatorDesc.Device = m_d3d12Device;
113- allocatorDesc.IsUMA = arch.UMA ;
114114 allocatorDesc.ResourceHeapTier = options.ResourceHeapTier ;
115115
116116#ifdef WEBNN_ENABLE_RESOURCE_DUMP
117- allocatorDesc.RecordOptions .Flags |= gpgmm::d3d12::ALLOCATOR_RECORD_FLAG_ALL_EVENTS ;
117+ allocatorDesc.RecordOptions .Flags |= gpgmm::d3d12::EVENT_RECORD_FLAG_ALL_EVENTS ;
118118 allocatorDesc.RecordOptions .MinMessageLevel = D3D12_MESSAGE_SEVERITY_MESSAGE ;
119119 allocatorDesc.RecordOptions .UseDetailedTimingEvents = true ;
120120#endif
@@ -299,8 +299,8 @@ HRESULT Device::DispatchOperator(
299299
300300 DML_BINDING_TABLE_DESC bindingTableDesc = {};
301301 bindingTableDesc.Dispatchable = op;
302- bindingTableDesc.CPUDescriptorHandle = m_descriptorHeap->m_Heap ->GetCPUDescriptorHandleForHeapStart ();
303- bindingTableDesc.GPUDescriptorHandle = m_descriptorHeap->m_Heap ->GetGPUDescriptorHandleForHeapStart ();
302+ bindingTableDesc.CPUDescriptorHandle = m_descriptorHeap->GetDescriptorHeap () ->GetCPUDescriptorHandleForHeapStart ();
303+ bindingTableDesc.GPUDescriptorHandle = m_descriptorHeap->GetDescriptorHeap () ->GetGPUDescriptorHandleForHeapStart ();
304304 bindingTableDesc.SizeInDescriptors = bindingProps.RequiredDescriptorCount ;
305305
306306 ReturnIfFailed (m_bindingTable->Reset (&bindingTableDesc));
@@ -339,7 +339,8 @@ HRESULT Device::DispatchOperator(
339339 }
340340
341341 // Record and execute commands, and wait for completion
342- m_commandList->SetDescriptorHeaps (1 , m_descriptorHeap->m_Heap .GetAddressOf ());
342+ ID3D12DescriptorHeap* descriptorHeap = m_descriptorHeap->GetDescriptorHeap ();
343+ m_commandList->SetDescriptorHeaps (1 , &descriptorHeap);
343344 m_commandRecorder->RecordDispatch (m_commandList.Get (), op, m_bindingTable.Get ());
344345 RecordOutputReadBack (outputsResourceSize);
345346 ReturnIfFailed (ExecuteCommandListAndWait ());
@@ -536,8 +537,8 @@ HRESULT Device::InitializeOperator(
536537
537538 DML_BINDING_TABLE_DESC bindingTableDesc = {};
538539 bindingTableDesc.Dispatchable = m_operatorInitializer.Get ();
539- bindingTableDesc.CPUDescriptorHandle = m_descriptorHeap->m_Heap ->GetCPUDescriptorHandleForHeapStart ();
540- bindingTableDesc.GPUDescriptorHandle = m_descriptorHeap->m_Heap ->GetGPUDescriptorHandleForHeapStart ();
540+ bindingTableDesc.CPUDescriptorHandle = m_descriptorHeap->GetDescriptorHeap () ->GetCPUDescriptorHandleForHeapStart ();
541+ bindingTableDesc.GPUDescriptorHandle = m_descriptorHeap->GetDescriptorHeap () ->GetGPUDescriptorHandleForHeapStart ();
541542 bindingTableDesc.SizeInDescriptors = descriptorHeapSize;
542543
543544 ReturnIfFailed (m_bindingTable->Reset (&bindingTableDesc));
@@ -560,7 +561,8 @@ HRESULT Device::InitializeOperator(
560561 }
561562
562563 // Record and execute commands, and wait for completion
563- m_commandList->SetDescriptorHeaps (1 , m_descriptorHeap->m_Heap .GetAddressOf ());
564+ ID3D12DescriptorHeap* descriptorHeap = m_descriptorHeap->GetDescriptorHeap ();
565+ m_commandList->SetDescriptorHeaps (1 , &descriptorHeap);
564566 m_commandRecorder->RecordDispatch (m_commandList.Get (), m_operatorInitializer.Get (), m_bindingTable.Get ());
565567 ReturnIfFailed (ExecuteCommandListAndWait ());
566568 return S_OK ;
@@ -655,35 +657,42 @@ HRESULT Device::EnsureDefaultBufferSize(uint64_t requestedSizeInBytes, _Inout_ C
655657
656658HRESULT Device::EnsureDescriptorHeapSize (uint32_t requestedSizeInDescriptors)
657659{
658- uint32_t existingSize = m_descriptorHeap ? m_descriptorHeap->m_Heap ->GetDesc ().NumDescriptors : 0 ;
660+ uint32_t existingSize = m_descriptorHeap ? m_descriptorHeap->GetDescriptorHeap () ->GetDesc ().NumDescriptors : 0 ;
659661 uint32_t newSize = RoundUpToPow2 (requestedSizeInDescriptors); // ensures geometric growth
660662
661663 if (newSize != existingSize)
662664 {
663665 if (m_descriptorHeap != nullptr && m_residencyManager != nullptr ){
664- m_residencyManager->UnlockHeap (m_descriptorHeap. get ());
666+ m_residencyManager->UnlockHeap (m_descriptorHeap-> m_Heap . Get ());
665667 }
666668
667669 m_descriptorHeap = nullptr ;
668670
669- if (m_residencyManager != nullptr ){
670- ReturnIfFailed (m_residencyManager->Evict (newSize, DXGI_MEMORY_SEGMENT_GROUP_LOCAL ));
671- }
672-
673- D3D12_DESCRIPTOR_HEAP_DESC desc = {};
674- desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV ;
675- desc.NumDescriptors = newSize;
676- desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE ;
677-
678- ComPtr<ID3D12DescriptorHeap> d3d12DescriptorHeap;
679- ReturnIfFailed (m_d3d12Device->CreateDescriptorHeap (&desc, IID_GRAPHICS_PPV_ARGS (d3d12DescriptorHeap.GetAddressOf ())));
680-
681- m_descriptorHeap = std::make_unique<SVDescriptorHeap>(std::move (d3d12DescriptorHeap), newSize);
671+ auto createHeapFn = [&](ID3D12Pageable** ppPageableOut) -> HRESULT {
672+ ComPtr<ID3D12DescriptorHeap> d3d12DescriptorHeap;
673+ D3D12_DESCRIPTOR_HEAP_DESC desc = {};
674+ desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV ;
675+ desc.NumDescriptors = newSize;
676+ desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE ;
677+ ReturnIfFailed (m_d3d12Device->CreateDescriptorHeap (
678+ &desc, IID_PPV_ARGS (&d3d12DescriptorHeap)));
679+ *ppPageableOut = d3d12DescriptorHeap.Detach ();
680+ return S_OK ;
681+ };
682+
683+ gpgmm::d3d12::HEAP_DESC heapDesc = {};
684+ heapDesc.SizeInBytes = newSize * m_d3d12Device->GetDescriptorHandleIncrementSize (D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV );
685+ heapDesc.MemorySegment = gpgmm::d3d12::RESIDENCY_SEGMENT_LOCAL ;
686+
687+ ComPtr<gpgmm::d3d12::Heap> descriptorHeap;
688+ ReturnIfFailed (gpgmm::d3d12::Heap::CreateHeap (heapDesc, m_residencyManager.Get (), createHeapFn,
689+ &descriptorHeap));
682690
683691 if (m_residencyManager != nullptr ){
684- ReturnIfFailed (m_residencyManager->InsertHeap (m_descriptorHeap.get ()));
685- ReturnIfFailed (m_residencyManager->LockHeap (m_descriptorHeap.get ()));
692+ ReturnIfFailed (m_residencyManager->LockHeap (descriptorHeap.Get ()));
686693 }
694+
695+ m_descriptorHeap = std::make_unique<SVDescriptorHeap>(std::move (descriptorHeap));
687696 }
688697 return S_OK ;
689698}
0 commit comments