diff --git a/lib/API/MTL/MTLDevice.cpp b/lib/API/MTL/MTLDevice.cpp index bf3554bb2..1aeb171fa 100644 --- a/lib/API/MTL/MTLDevice.cpp +++ b/lib/API/MTL/MTLDevice.cpp @@ -609,6 +609,9 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder { /// the next encoder transition (via endEncodingImpl). MTL::AccelerationStructureCommandEncoder *ASEnc = nullptr; + /// Bound resource heap, read back for IRDispatchRaysArgument::ResDescHeap. + MTLDescriptorHeap *BoundResourceHeap = nullptr; + /// Accumulated barrier scope from commands recorded since the last barrier. MTL::BarrierScope PendingScope = MTL::BarrierScope(0); @@ -665,6 +668,14 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder { MTL::ComputeCommandEncoder *getNative() const { return ComputeEnc; } + llvm::Error bindResourceHeap(MTLDescriptorHeap &Heap) { + if (auto Err = ensureComputeEncoder()) + return Err; + Heap.bind(ComputeEnc); + BoundResourceHeap = &Heap; + return llvm::Error::success(); + } + MTL::CommandEncoder *getActiveEncoder() const { if (ComputeEnc) return ComputeEnc; @@ -787,40 +798,9 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder { // MTL::Device handle (used to allocate scratch and instance buffers). llvm::Error batchBuildAS(llvm::ArrayRef Items) override; - // Dispatch threads using a raygen compute kernel synthesized by the - // irconverter. All bindings (descriptor heap, top-level argument buffer, - // IRDispatchRaysArgument at slot 3, visible/intersection function tables, - // and the SBT buffer) must already be set on the active compute encoder by - // the caller — this method only binds the pipeline state and issues the - // dispatch. - llvm::Error dispatchRays(const PipelineState &PSO, const ShaderBindingTable &, - uint32_t Width, uint32_t Height, - uint32_t Depth) override { - if (!llvm::isa(&PSO)) - return llvm::createStringError( - std::errc::invalid_argument, - "dispatchRays requires a RayTracing PipelineState."); - const auto &RTPSO = llvm::cast(PSO); - if (!RTPSO.ComputePipeline) - return llvm::createStringError( - std::errc::invalid_argument, - "RayTracing PipelineState has no compute pipeline state."); - if (auto Err = ensureComputeEncoder()) - return Err; - flushBarrier(); - insertDebugSignpost( - llvm::formatv("DispatchRays [{0},{1},{2}]", Width, Height, Depth) - .str()); - ComputeEnc->setComputePipelineState(RTPSO.ComputePipeline); - - // DispatchRays(W, H, D) launches W*H*D rays; tid in the irconverter raygen - // kernel is the per-ray index. Pass grid as raw (W, H, D) and let Metal - // ceil-divide by ThreadsPerGroup to compute threadgroup count. - const MTL::Size GridSize(Width, Height, Depth); - ComputeEnc->dispatchThreads(GridSize, RTPSO.ThreadsPerGroup); - addBarrierScope(MTL::BarrierScopeBuffers | MTL::BarrierScopeTextures); - return llvm::Error::success(); - } + llvm::Error dispatchRays(const PipelineState &PSO, + const ShaderBindingTable &SBT, uint32_t Width, + uint32_t Height, uint32_t Depth) override; /// Lazily transition into an AccelerationStructureCommandEncoder; mirrors /// the existing compute↔blit lazy switch. @@ -1742,7 +1722,8 @@ class MTLDevice : public offloadtest::Device { const auto &PS = llvm::cast(IS.Pipeline.get()); MTLGPUDescriptorHandle Handle = {}; if (IS.DescHeap) { - IS.DescHeap->bind(NativeEncoder); + if (auto Err = Encoder.bindResourceHeap(*IS.DescHeap)) + return Err; Handle = IS.DescHeap->getGPUDescriptorHandleForHeapStart(); } @@ -1772,7 +1753,15 @@ class MTLDevice : public offloadtest::Device { NativeEncoder->useResource(llvm::cast(B.get())->Buf, MTL::ResourceUsageRead); - if (auto Err = Encoder.dispatch(*IS.Pipeline.get(), + if (P.isRayTracing()) { + if (auto Err = + Encoder.dispatchRays(*IS.Pipeline.get(), *IS.SBT.get(), + P.DispatchParameters.DispatchGroupCount[0], + P.DispatchParameters.DispatchGroupCount[1], + P.DispatchParameters.DispatchGroupCount[2])) + return Err; + } else if (auto Err = + Encoder.dispatch(*IS.Pipeline.get(), P.DispatchParameters.DispatchGroupCount[0], P.DispatchParameters.DispatchGroupCount[1], P.DispatchParameters.DispatchGroupCount[2])) @@ -1781,112 +1770,6 @@ class MTLDevice : public offloadtest::Device { return llvm::Error::success(); } - llvm::Error createRayTracingCommands(Pipeline &P, InvocationState &IS) { - auto EncoderOrErr = IS.CB->createComputeEncoder(); - if (!EncoderOrErr) - return EncoderOrErr.takeError(); - auto &Encoder = llvm::cast(*EncoderOrErr.get()); - MTL::ComputeCommandEncoder *NativeEncoder = Encoder.getNative(); - - const auto &RTPSO = - llvm::cast(*IS.Pipeline.get()); - const auto &SBT = llvm::cast(*IS.SBT.get()); - - // Bind the global descriptor heap + top-level argument buffer the same - // way the compute path does; the raygen kernel and any visible-function - // callees consume them at the same slots (kIRDescriptorHeapBindPoint and - // kIRArgumentBufferBindPoint). - MTLGPUDescriptorHandle Handle = {}; - if (IS.DescHeap) { - IS.DescHeap->bind(NativeEncoder); - Handle = IS.DescHeap->getGPUDescriptorHandleForHeapStart(); - } - for (uint32_t Idx = 0u; Idx < P.Sets.size(); ++Idx) { - RTPSO.ArgBuffer->setRootDescriptorTable(Idx, Handle); - Handle.addOffset(P.Sets[Idx].Resources.size()); - } - RTPSO.ArgBuffer->bind(NativeEncoder); - - // Populate the per-dispatch IRDispatchRaysArgument: SBT region addresses - // (RayGen / Miss / HitGroup / Callable), GPU pointers to the global - // root-signature argument buffer + descriptor heaps, plus resource IDs - // for the visible / intersection function tables. The raygen kernel - // reads this struct from the buffer bound at kIRRayDispatchArgumentsBind- - // Point and any visible-function callees inherit it through the same - // pointer. - IRDispatchRaysArgument Args{}; - Args.DispatchRaysDesc.RayGenerationShaderRecord = SBT.RayGenRegion; - Args.DispatchRaysDesc.MissShaderTable = SBT.MissRegion; - Args.DispatchRaysDesc.HitGroupTable = SBT.HitGroupRegion; - Args.DispatchRaysDesc.CallableShaderTable = SBT.CallableRegion; - Args.DispatchRaysDesc.Width = P.DispatchParameters.DispatchGroupCount[0]; - Args.DispatchRaysDesc.Height = P.DispatchParameters.DispatchGroupCount[1]; - Args.DispatchRaysDesc.Depth = P.DispatchParameters.DispatchGroupCount[2]; - Args.GRS = RTPSO.ArgBuffer->getGPUAddress(); - Args.ResDescHeap = - IS.DescHeap ? IS.DescHeap->getGPUDescriptorHandleForHeapStart().Ptr : 0; - Args.SmpDescHeap = 0; - Args.VisibleFunctionTable = - RTPSO.VFT ? RTPSO.VFT->gpuResourceID() : MTL::ResourceID{0}; - Args.IntersectionFunctionTable = - RTPSO.IFT ? RTPSO.IFT->gpuResourceID() : MTL::ResourceID{0}; - Args.IntersectionFunctionTables = 0; - - const BufferCreateDesc ArgsBufDesc = BufferCreateDesc::uploadBuffer(); - auto ArgsBufOrErr = offloadtest::createBufferWithData( - *IS.CB->Dev, "MTL Dispatch Rays Arguments", ArgsBufDesc, &Args, - sizeof(IRDispatchRaysArgument), nullptr, nullptr); - if (!ArgsBufOrErr) - return ArgsBufOrErr.takeError(); - - auto *MTLArgsBuf = llvm::cast(ArgsBufOrErr->get()); - IS.CB->KeepAliveOwned.push_back(std::move(*ArgsBufOrErr)); - - NativeEncoder->setBuffer(MTLArgsBuf->Buf, 0, - kIRRayDispatchArgumentsBindPoint); - NativeEncoder->useResource(MTLArgsBuf->Buf, MTL::ResourceUsageRead); - - // Mark every dispatch-side resource resident: descriptor-table bundles, - // acceleration structures + their irconverter header/contribution - // buffers (so RayQuery/TraceRay can read them), the SBT buffer (the - // raygen kernel dereferences SBT addresses), and the visible / - // intersection function tables. - for (const auto &Table : IS.DescTables) - for (const auto &ResPair : Table.Resources) - for (const auto &ResSet : ResPair.second) - NativeEncoder->useResource(ResSet.Resource.get(), - MTL::ResourceUsageRead | - MTL::ResourceUsageWrite); - auto MarkASResident = - [&](std::unique_ptr &AS) { - auto *MTLAS = llvm::cast(AS.get()); - NativeEncoder->useResource(MTLAS->AccelStruct, - MTL::ResourceUsageRead); - }; - for (auto &AS : IS.BLASes) - MarkASResident(AS); - for (auto &Entry : IS.TLASes) - MarkASResident(Entry.second); - for (auto &B : IS.ASDescriptorBuffers) - NativeEncoder->useResource(llvm::cast(B.get())->Buf, - MTL::ResourceUsageRead); - if (SBT.Buffer) - NativeEncoder->useResource(SBT.Buffer, MTL::ResourceUsageRead); - if (RTPSO.VFT) - NativeEncoder->useResource(RTPSO.VFT, MTL::ResourceUsageRead); - if (RTPSO.IFT) - NativeEncoder->useResource(RTPSO.IFT, MTL::ResourceUsageRead); - - if (auto Err = - Encoder.dispatchRays(*IS.Pipeline.get(), *IS.SBT.get(), - P.DispatchParameters.DispatchGroupCount[0], - P.DispatchParameters.DispatchGroupCount[1], - P.DispatchParameters.DispatchGroupCount[2])) - return Err; - Encoder.endEncoding(); - return llvm::Error::success(); - } - llvm::Error createRenderTarget(Pipeline &P, InvocationState &IS) { if (!P.Bindings.RTargetBufferPtr) return llvm::createStringError( @@ -3094,7 +2977,7 @@ class MTLDevice : public offloadtest::Device { IS.SBT = std::move(*SBTOrErr); llvm::outs() << "Shader Binding Table created.\n"; - if (auto Err = createRayTracingCommands(P, IS)) + if (auto Err = createComputeCommands(P, IS)) return Err; } @@ -3303,6 +3186,78 @@ llvm::Error MTLComputeEncoder::batchBuildAS(llvm::ArrayRef Items) { return llvm::Error::success(); } + +llvm::Error MTLComputeEncoder::dispatchRays(const PipelineState &PSO, + const ShaderBindingTable &SBT, + uint32_t Width, uint32_t Height, + uint32_t Depth) { + if (!llvm::isa(&PSO)) + return llvm::createStringError( + std::errc::invalid_argument, + "dispatchRays requires a RayTracing PipelineState."); + if (!llvm::isa(&SBT)) + return llvm::createStringError( + std::errc::invalid_argument, + "dispatchRays requires a Metal ShaderBindingTable."); + const auto &RTPSO = llvm::cast(PSO); + const auto &MTLSBT = llvm::cast(SBT); + if (!RTPSO.ComputePipeline) + return llvm::createStringError( + std::errc::invalid_argument, + "RayTracing PipelineState has no compute pipeline state."); + if (auto Err = ensureComputeEncoder()) + return Err; + flushBarrier(); + insertDebugSignpost( + llvm::formatv("DispatchRays [{0},{1},{2}]", Width, Height, Depth).str()); + + // Per-dispatch ray arguments, consumed at kIRRayDispatchArgumentsBindPoint. + IRDispatchRaysArgument Args{}; + Args.DispatchRaysDesc.RayGenerationShaderRecord = MTLSBT.RayGenRegion; + Args.DispatchRaysDesc.MissShaderTable = MTLSBT.MissRegion; + Args.DispatchRaysDesc.HitGroupTable = MTLSBT.HitGroupRegion; + Args.DispatchRaysDesc.CallableShaderTable = MTLSBT.CallableRegion; + Args.DispatchRaysDesc.Width = Width; + Args.DispatchRaysDesc.Height = Height; + Args.DispatchRaysDesc.Depth = Depth; + Args.GRS = RTPSO.ArgBuffer->getGPUAddress(); + Args.ResDescHeap = + BoundResourceHeap + ? BoundResourceHeap->getGPUDescriptorHandleForHeapStart().Ptr + : 0; + Args.SmpDescHeap = 0; + Args.VisibleFunctionTable = + RTPSO.VFT ? RTPSO.VFT->gpuResourceID() : MTL::ResourceID{0}; + Args.IntersectionFunctionTable = + RTPSO.IFT ? RTPSO.IFT->gpuResourceID() : MTL::ResourceID{0}; + Args.IntersectionFunctionTables = 0; + + const BufferCreateDesc ArgsBufDesc = BufferCreateDesc::uploadBuffer(); + auto ArgsBufOrErr = offloadtest::createBufferWithData( + *CB->Dev, "MTL Dispatch Rays Arguments", ArgsBufDesc, &Args, + sizeof(IRDispatchRaysArgument), nullptr, nullptr); + if (!ArgsBufOrErr) + return ArgsBufOrErr.takeError(); + auto *MTLArgsBuf = llvm::cast(ArgsBufOrErr->get()); + CB->KeepAliveOwned.push_back(std::move(*ArgsBufOrErr)); + + ComputeEnc->setBuffer(MTLArgsBuf->Buf, 0, kIRRayDispatchArgumentsBindPoint); + ComputeEnc->useResource(MTLArgsBuf->Buf, MTL::ResourceUsageRead); + + if (MTLSBT.Buffer) + ComputeEnc->useResource(MTLSBT.Buffer, MTL::ResourceUsageRead); + if (RTPSO.VFT) + ComputeEnc->useResource(RTPSO.VFT, MTL::ResourceUsageRead); + if (RTPSO.IFT) + ComputeEnc->useResource(RTPSO.IFT, MTL::ResourceUsageRead); + + ComputeEnc->setComputePipelineState(RTPSO.ComputePipeline); + // dispatchThreads automatically takes care of bounds checking. + const MTL::Size GridSize(Width, Height, Depth); + ComputeEnc->dispatchThreads(GridSize, RTPSO.ThreadsPerGroup); + addBarrierScope(MTL::BarrierScopeBuffers | MTL::BarrierScopeTextures); + return llvm::Error::success(); +} } // namespace llvm::Error offloadtest::initializeMetalDevices(