From f3899bd12b76f1711a159aa6a6cfd88efb182999 Mon Sep 17 00:00:00 2001 From: Marijn Suijten Date: Tue, 30 Jun 2026 14:48:42 +0200 Subject: [PATCH] [Metal] Make dispatchRays a self-contained encoder operation The DX and Vulkan backends record a ray dispatch by creating an encoder and calling Encoder.dispatchRays() directly. Metal instead went through createRayTracingCommands, a device method that bound the descriptor heap and argument buffer, built and bound the IRDispatchRaysArgument, marked the RT resources resident, and only then called a thin dispatchRays(). Since Metal lowers raygen to a compute kernel, ray tracing now records through the shared createComputeCommands path (which already binds the heap and argument buffer and marks the common resources resident) by branching on P.isRayTracing(). The IRDispatchRaysArgument setup and RT-specific residency move into MTLComputeEncoder::dispatchRays, defined out-of-line so it can allocate via CB->Dev and keep the buffer alive in CB->KeepAliveOwned. The heap address for ResDescHeap, absent from the cross-backend signature, is remembered on the encoder via bindResourceHeap(). createRayTracingCommands is removed. Co-Authored-By: Claude Opus 4.8 (1M context) --- lib/API/MTL/MTLDevice.cpp | 241 ++++++++++++++++---------------------- 1 file changed, 98 insertions(+), 143 deletions(-) diff --git a/lib/API/MTL/MTLDevice.cpp b/lib/API/MTL/MTLDevice.cpp index bf3554bb2..1aeb171fa 100644 --- a/lib/API/MTL/MTLDevice.cpp +++ b/lib/API/MTL/MTLDevice.cpp @@ -609,6 +609,9 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder { /// the next encoder transition (via endEncodingImpl). MTL::AccelerationStructureCommandEncoder *ASEnc = nullptr; + /// Bound resource heap, read back for IRDispatchRaysArgument::ResDescHeap. + MTLDescriptorHeap *BoundResourceHeap = nullptr; + /// Accumulated barrier scope from commands recorded since the last barrier. MTL::BarrierScope PendingScope = MTL::BarrierScope(0); @@ -665,6 +668,14 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder { MTL::ComputeCommandEncoder *getNative() const { return ComputeEnc; } + llvm::Error bindResourceHeap(MTLDescriptorHeap &Heap) { + if (auto Err = ensureComputeEncoder()) + return Err; + Heap.bind(ComputeEnc); + BoundResourceHeap = &Heap; + return llvm::Error::success(); + } + MTL::CommandEncoder *getActiveEncoder() const { if (ComputeEnc) return ComputeEnc; @@ -787,40 +798,9 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder { // MTL::Device handle (used to allocate scratch and instance buffers). llvm::Error batchBuildAS(llvm::ArrayRef Items) override; - // Dispatch threads using a raygen compute kernel synthesized by the - // irconverter. All bindings (descriptor heap, top-level argument buffer, - // IRDispatchRaysArgument at slot 3, visible/intersection function tables, - // and the SBT buffer) must already be set on the active compute encoder by - // the caller — this method only binds the pipeline state and issues the - // dispatch. - llvm::Error dispatchRays(const PipelineState &PSO, const ShaderBindingTable &, - uint32_t Width, uint32_t Height, - uint32_t Depth) override { - if (!llvm::isa(&PSO)) - return llvm::createStringError( - std::errc::invalid_argument, - "dispatchRays requires a RayTracing PipelineState."); - const auto &RTPSO = llvm::cast(PSO); - if (!RTPSO.ComputePipeline) - return llvm::createStringError( - std::errc::invalid_argument, - "RayTracing PipelineState has no compute pipeline state."); - if (auto Err = ensureComputeEncoder()) - return Err; - flushBarrier(); - insertDebugSignpost( - llvm::formatv("DispatchRays [{0},{1},{2}]", Width, Height, Depth) - .str()); - ComputeEnc->setComputePipelineState(RTPSO.ComputePipeline); - - // DispatchRays(W, H, D) launches W*H*D rays; tid in the irconverter raygen - // kernel is the per-ray index. Pass grid as raw (W, H, D) and let Metal - // ceil-divide by ThreadsPerGroup to compute threadgroup count. - const MTL::Size GridSize(Width, Height, Depth); - ComputeEnc->dispatchThreads(GridSize, RTPSO.ThreadsPerGroup); - addBarrierScope(MTL::BarrierScopeBuffers | MTL::BarrierScopeTextures); - return llvm::Error::success(); - } + llvm::Error dispatchRays(const PipelineState &PSO, + const ShaderBindingTable &SBT, uint32_t Width, + uint32_t Height, uint32_t Depth) override; /// Lazily transition into an AccelerationStructureCommandEncoder; mirrors /// the existing compute↔blit lazy switch. @@ -1742,7 +1722,8 @@ class MTLDevice : public offloadtest::Device { const auto &PS = llvm::cast(IS.Pipeline.get()); MTLGPUDescriptorHandle Handle = {}; if (IS.DescHeap) { - IS.DescHeap->bind(NativeEncoder); + if (auto Err = Encoder.bindResourceHeap(*IS.DescHeap)) + return Err; Handle = IS.DescHeap->getGPUDescriptorHandleForHeapStart(); } @@ -1772,7 +1753,15 @@ class MTLDevice : public offloadtest::Device { NativeEncoder->useResource(llvm::cast(B.get())->Buf, MTL::ResourceUsageRead); - if (auto Err = Encoder.dispatch(*IS.Pipeline.get(), + if (P.isRayTracing()) { + if (auto Err = + Encoder.dispatchRays(*IS.Pipeline.get(), *IS.SBT.get(), + P.DispatchParameters.DispatchGroupCount[0], + P.DispatchParameters.DispatchGroupCount[1], + P.DispatchParameters.DispatchGroupCount[2])) + return Err; + } else if (auto Err = + Encoder.dispatch(*IS.Pipeline.get(), P.DispatchParameters.DispatchGroupCount[0], P.DispatchParameters.DispatchGroupCount[1], P.DispatchParameters.DispatchGroupCount[2])) @@ -1781,112 +1770,6 @@ class MTLDevice : public offloadtest::Device { return llvm::Error::success(); } - llvm::Error createRayTracingCommands(Pipeline &P, InvocationState &IS) { - auto EncoderOrErr = IS.CB->createComputeEncoder(); - if (!EncoderOrErr) - return EncoderOrErr.takeError(); - auto &Encoder = llvm::cast(*EncoderOrErr.get()); - MTL::ComputeCommandEncoder *NativeEncoder = Encoder.getNative(); - - const auto &RTPSO = - llvm::cast(*IS.Pipeline.get()); - const auto &SBT = llvm::cast(*IS.SBT.get()); - - // Bind the global descriptor heap + top-level argument buffer the same - // way the compute path does; the raygen kernel and any visible-function - // callees consume them at the same slots (kIRDescriptorHeapBindPoint and - // kIRArgumentBufferBindPoint). - MTLGPUDescriptorHandle Handle = {}; - if (IS.DescHeap) { - IS.DescHeap->bind(NativeEncoder); - Handle = IS.DescHeap->getGPUDescriptorHandleForHeapStart(); - } - for (uint32_t Idx = 0u; Idx < P.Sets.size(); ++Idx) { - RTPSO.ArgBuffer->setRootDescriptorTable(Idx, Handle); - Handle.addOffset(P.Sets[Idx].Resources.size()); - } - RTPSO.ArgBuffer->bind(NativeEncoder); - - // Populate the per-dispatch IRDispatchRaysArgument: SBT region addresses - // (RayGen / Miss / HitGroup / Callable), GPU pointers to the global - // root-signature argument buffer + descriptor heaps, plus resource IDs - // for the visible / intersection function tables. The raygen kernel - // reads this struct from the buffer bound at kIRRayDispatchArgumentsBind- - // Point and any visible-function callees inherit it through the same - // pointer. - IRDispatchRaysArgument Args{}; - Args.DispatchRaysDesc.RayGenerationShaderRecord = SBT.RayGenRegion; - Args.DispatchRaysDesc.MissShaderTable = SBT.MissRegion; - Args.DispatchRaysDesc.HitGroupTable = SBT.HitGroupRegion; - Args.DispatchRaysDesc.CallableShaderTable = SBT.CallableRegion; - Args.DispatchRaysDesc.Width = P.DispatchParameters.DispatchGroupCount[0]; - Args.DispatchRaysDesc.Height = P.DispatchParameters.DispatchGroupCount[1]; - Args.DispatchRaysDesc.Depth = P.DispatchParameters.DispatchGroupCount[2]; - Args.GRS = RTPSO.ArgBuffer->getGPUAddress(); - Args.ResDescHeap = - IS.DescHeap ? IS.DescHeap->getGPUDescriptorHandleForHeapStart().Ptr : 0; - Args.SmpDescHeap = 0; - Args.VisibleFunctionTable = - RTPSO.VFT ? RTPSO.VFT->gpuResourceID() : MTL::ResourceID{0}; - Args.IntersectionFunctionTable = - RTPSO.IFT ? RTPSO.IFT->gpuResourceID() : MTL::ResourceID{0}; - Args.IntersectionFunctionTables = 0; - - const BufferCreateDesc ArgsBufDesc = BufferCreateDesc::uploadBuffer(); - auto ArgsBufOrErr = offloadtest::createBufferWithData( - *IS.CB->Dev, "MTL Dispatch Rays Arguments", ArgsBufDesc, &Args, - sizeof(IRDispatchRaysArgument), nullptr, nullptr); - if (!ArgsBufOrErr) - return ArgsBufOrErr.takeError(); - - auto *MTLArgsBuf = llvm::cast(ArgsBufOrErr->get()); - IS.CB->KeepAliveOwned.push_back(std::move(*ArgsBufOrErr)); - - NativeEncoder->setBuffer(MTLArgsBuf->Buf, 0, - kIRRayDispatchArgumentsBindPoint); - NativeEncoder->useResource(MTLArgsBuf->Buf, MTL::ResourceUsageRead); - - // Mark every dispatch-side resource resident: descriptor-table bundles, - // acceleration structures + their irconverter header/contribution - // buffers (so RayQuery/TraceRay can read them), the SBT buffer (the - // raygen kernel dereferences SBT addresses), and the visible / - // intersection function tables. - for (const auto &Table : IS.DescTables) - for (const auto &ResPair : Table.Resources) - for (const auto &ResSet : ResPair.second) - NativeEncoder->useResource(ResSet.Resource.get(), - MTL::ResourceUsageRead | - MTL::ResourceUsageWrite); - auto MarkASResident = - [&](std::unique_ptr &AS) { - auto *MTLAS = llvm::cast(AS.get()); - NativeEncoder->useResource(MTLAS->AccelStruct, - MTL::ResourceUsageRead); - }; - for (auto &AS : IS.BLASes) - MarkASResident(AS); - for (auto &Entry : IS.TLASes) - MarkASResident(Entry.second); - for (auto &B : IS.ASDescriptorBuffers) - NativeEncoder->useResource(llvm::cast(B.get())->Buf, - MTL::ResourceUsageRead); - if (SBT.Buffer) - NativeEncoder->useResource(SBT.Buffer, MTL::ResourceUsageRead); - if (RTPSO.VFT) - NativeEncoder->useResource(RTPSO.VFT, MTL::ResourceUsageRead); - if (RTPSO.IFT) - NativeEncoder->useResource(RTPSO.IFT, MTL::ResourceUsageRead); - - if (auto Err = - Encoder.dispatchRays(*IS.Pipeline.get(), *IS.SBT.get(), - P.DispatchParameters.DispatchGroupCount[0], - P.DispatchParameters.DispatchGroupCount[1], - P.DispatchParameters.DispatchGroupCount[2])) - return Err; - Encoder.endEncoding(); - return llvm::Error::success(); - } - llvm::Error createRenderTarget(Pipeline &P, InvocationState &IS) { if (!P.Bindings.RTargetBufferPtr) return llvm::createStringError( @@ -3094,7 +2977,7 @@ class MTLDevice : public offloadtest::Device { IS.SBT = std::move(*SBTOrErr); llvm::outs() << "Shader Binding Table created.\n"; - if (auto Err = createRayTracingCommands(P, IS)) + if (auto Err = createComputeCommands(P, IS)) return Err; } @@ -3303,6 +3186,78 @@ llvm::Error MTLComputeEncoder::batchBuildAS(llvm::ArrayRef Items) { return llvm::Error::success(); } + +llvm::Error MTLComputeEncoder::dispatchRays(const PipelineState &PSO, + const ShaderBindingTable &SBT, + uint32_t Width, uint32_t Height, + uint32_t Depth) { + if (!llvm::isa(&PSO)) + return llvm::createStringError( + std::errc::invalid_argument, + "dispatchRays requires a RayTracing PipelineState."); + if (!llvm::isa(&SBT)) + return llvm::createStringError( + std::errc::invalid_argument, + "dispatchRays requires a Metal ShaderBindingTable."); + const auto &RTPSO = llvm::cast(PSO); + const auto &MTLSBT = llvm::cast(SBT); + if (!RTPSO.ComputePipeline) + return llvm::createStringError( + std::errc::invalid_argument, + "RayTracing PipelineState has no compute pipeline state."); + if (auto Err = ensureComputeEncoder()) + return Err; + flushBarrier(); + insertDebugSignpost( + llvm::formatv("DispatchRays [{0},{1},{2}]", Width, Height, Depth).str()); + + // Per-dispatch ray arguments, consumed at kIRRayDispatchArgumentsBindPoint. + IRDispatchRaysArgument Args{}; + Args.DispatchRaysDesc.RayGenerationShaderRecord = MTLSBT.RayGenRegion; + Args.DispatchRaysDesc.MissShaderTable = MTLSBT.MissRegion; + Args.DispatchRaysDesc.HitGroupTable = MTLSBT.HitGroupRegion; + Args.DispatchRaysDesc.CallableShaderTable = MTLSBT.CallableRegion; + Args.DispatchRaysDesc.Width = Width; + Args.DispatchRaysDesc.Height = Height; + Args.DispatchRaysDesc.Depth = Depth; + Args.GRS = RTPSO.ArgBuffer->getGPUAddress(); + Args.ResDescHeap = + BoundResourceHeap + ? BoundResourceHeap->getGPUDescriptorHandleForHeapStart().Ptr + : 0; + Args.SmpDescHeap = 0; + Args.VisibleFunctionTable = + RTPSO.VFT ? RTPSO.VFT->gpuResourceID() : MTL::ResourceID{0}; + Args.IntersectionFunctionTable = + RTPSO.IFT ? RTPSO.IFT->gpuResourceID() : MTL::ResourceID{0}; + Args.IntersectionFunctionTables = 0; + + const BufferCreateDesc ArgsBufDesc = BufferCreateDesc::uploadBuffer(); + auto ArgsBufOrErr = offloadtest::createBufferWithData( + *CB->Dev, "MTL Dispatch Rays Arguments", ArgsBufDesc, &Args, + sizeof(IRDispatchRaysArgument), nullptr, nullptr); + if (!ArgsBufOrErr) + return ArgsBufOrErr.takeError(); + auto *MTLArgsBuf = llvm::cast(ArgsBufOrErr->get()); + CB->KeepAliveOwned.push_back(std::move(*ArgsBufOrErr)); + + ComputeEnc->setBuffer(MTLArgsBuf->Buf, 0, kIRRayDispatchArgumentsBindPoint); + ComputeEnc->useResource(MTLArgsBuf->Buf, MTL::ResourceUsageRead); + + if (MTLSBT.Buffer) + ComputeEnc->useResource(MTLSBT.Buffer, MTL::ResourceUsageRead); + if (RTPSO.VFT) + ComputeEnc->useResource(RTPSO.VFT, MTL::ResourceUsageRead); + if (RTPSO.IFT) + ComputeEnc->useResource(RTPSO.IFT, MTL::ResourceUsageRead); + + ComputeEnc->setComputePipelineState(RTPSO.ComputePipeline); + // dispatchThreads automatically takes care of bounds checking. + const MTL::Size GridSize(Width, Height, Depth); + ComputeEnc->dispatchThreads(GridSize, RTPSO.ThreadsPerGroup); + addBarrierScope(MTL::BarrierScopeBuffers | MTL::BarrierScopeTextures); + return llvm::Error::success(); +} } // namespace llvm::Error offloadtest::initializeMetalDevices(