@@ -609,6 +609,10 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
609609 // / the next encoder transition (via endEncodingImpl).
610610 MTL ::AccelerationStructureCommandEncoder *ASEnc = nullptr ;
611611
612+ // / Resource descriptor heap bound to this encoder, if any. Remembered so a
613+ // / ray dispatch can point IRDispatchRaysArgument::ResDescHeap at it.
614+ MTLDescriptorHeap *BoundResourceHeap = nullptr ;
615+
612616 // / Accumulated barrier scope from commands recorded since the last barrier.
613617 MTL ::BarrierScope PendingScope = MTL ::BarrierScope(0 );
614618
@@ -665,6 +669,16 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
665669
666670 MTL ::ComputeCommandEncoder *getNative () const { return ComputeEnc; }
667671
672+ // / Bind \p Heap as the global resource descriptor heap and remember it so a
673+ // / subsequent dispatchRays can reference it from IRDispatchRaysArgument.
674+ llvm::Error bindResourceHeap (MTLDescriptorHeap &Heap) {
675+ if (auto Err = ensureComputeEncoder ())
676+ return Err;
677+ Heap.bind (ComputeEnc);
678+ BoundResourceHeap = &Heap;
679+ return llvm::Error::success ();
680+ }
681+
668682 MTL ::CommandEncoder *getActiveEncoder () const {
669683 if (ComputeEnc)
670684 return ComputeEnc;
@@ -787,40 +801,17 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
787801 // MTL::Device handle (used to allocate scratch and instance buffers).
788802 llvm::Error batchBuildAS (llvm::ArrayRef<ASBuildItem> Items) override ;
789803
790- // Dispatch threads using a raygen compute kernel synthesized by the
791- // irconverter. All bindings (descriptor heap, top-level argument buffer,
792- // IRDispatchRaysArgument at slot 3, visible/intersection function tables,
793- // and the SBT buffer) must already be set on the active compute encoder by
794- // the caller — this method only binds the pipeline state and issues the
795- // dispatch.
796- llvm::Error dispatchRays (const PipelineState &PSO , const ShaderBindingTable &,
797- uint32_t Width, uint32_t Height,
798- uint32_t Depth) override {
799- if (!llvm::isa<MTLRayTracingPipelineState>(&PSO ))
800- return llvm::createStringError (
801- std::errc::invalid_argument,
802- " dispatchRays requires a RayTracing PipelineState." );
803- const auto &RTPSO = llvm::cast<MTLRayTracingPipelineState>(PSO );
804- if (!RTPSO .ComputePipeline )
805- return llvm::createStringError (
806- std::errc::invalid_argument,
807- " RayTracing PipelineState has no compute pipeline state." );
808- if (auto Err = ensureComputeEncoder ())
809- return Err;
810- flushBarrier ();
811- insertDebugSignpost (
812- llvm::formatv (" DispatchRays [{0},{1},{2}]" , Width, Height, Depth)
813- .str ());
814- ComputeEnc->setComputePipelineState (RTPSO .ComputePipeline );
815-
816- // DispatchRays(W, H, D) launches W*H*D rays; tid in the irconverter raygen
817- // kernel is the per-ray index. Pass grid as raw (W, H, D) and let Metal
818- // ceil-divide by ThreadsPerGroup to compute threadgroup count.
819- const MTL ::Size GridSize (Width, Height, Depth);
820- ComputeEnc->dispatchThreads (GridSize, RTPSO .ThreadsPerGroup );
821- addBarrierScope (MTL ::BarrierScopeBuffers | MTL ::BarrierScopeTextures);
822- return llvm::Error::success ();
823- }
804+ // Issue a ray dispatch: synthesizes the per-dispatch IRDispatchRaysArgument
805+ // (SBT region addresses, GRS + resource-heap pointers, visible/intersection
806+ // function-table IDs), binds it at kIRRayDispatchArgumentsBindPoint, marks
807+ // the RT-specific resources resident, and dispatches the irconverter raygen
808+ // compute kernel. The caller must have already bound the descriptor heap and
809+ // top-level argument buffer (the compute-dispatch path does the same).
810+ // Defined out-of-line below — allocating the argument buffer needs
811+ // MTLDevice's full type.
812+ llvm::Error dispatchRays (const PipelineState &PSO ,
813+ const ShaderBindingTable &SBT , uint32_t Width,
814+ uint32_t Height, uint32_t Depth) override ;
824815
825816 // / Lazily transition into an AccelerationStructureCommandEncoder; mirrors
826817 // / the existing compute↔blit lazy switch.
@@ -1742,7 +1733,8 @@ class MTLDevice : public offloadtest::Device {
17421733 const auto &PS = llvm::cast<MTLPipelineState>(IS .Pipeline .get ());
17431734 MTLGPUDescriptorHandle Handle = {};
17441735 if (IS .DescHeap ) {
1745- IS .DescHeap ->bind (NativeEncoder);
1736+ if (auto Err = Encoder.bindResourceHeap (*IS .DescHeap ))
1737+ return Err;
17461738 Handle = IS .DescHeap ->getGPUDescriptorHandleForHeapStart ();
17471739 }
17481740
@@ -1772,7 +1764,18 @@ class MTLDevice : public offloadtest::Device {
17721764 NativeEncoder->useResource (llvm::cast<MTLBuffer>(B.get ())->Buf ,
17731765 MTL ::ResourceUsageRead);
17741766
1775- if (auto Err = Encoder.dispatch (*IS .Pipeline .get (),
1767+ // Metal compiles raygen as a compute kernel, so a ray-tracing pipeline
1768+ // dispatches through the same compute encoder as a plain compute shader;
1769+ // dispatchRays adds the RT-specific argument buffer and residency.
1770+ if (P.isRayTracing ()) {
1771+ if (auto Err =
1772+ Encoder.dispatchRays (*IS .Pipeline .get (), *IS .SBT .get (),
1773+ P.DispatchParameters .DispatchGroupCount [0 ],
1774+ P.DispatchParameters .DispatchGroupCount [1 ],
1775+ P.DispatchParameters .DispatchGroupCount [2 ]))
1776+ return Err;
1777+ } else if (auto Err =
1778+ Encoder.dispatch (*IS .Pipeline .get (),
17761779 P.DispatchParameters .DispatchGroupCount [0 ],
17771780 P.DispatchParameters .DispatchGroupCount [1 ],
17781781 P.DispatchParameters .DispatchGroupCount [2 ]))
@@ -1781,112 +1784,6 @@ class MTLDevice : public offloadtest::Device {
17811784 return llvm::Error::success ();
17821785 }
17831786
1784- llvm::Error createRayTracingCommands (Pipeline &P, InvocationState &IS ) {
1785- auto EncoderOrErr = IS .CB ->createComputeEncoder ();
1786- if (!EncoderOrErr)
1787- return EncoderOrErr.takeError ();
1788- auto &Encoder = llvm::cast<MTLComputeEncoder>(*EncoderOrErr.get ());
1789- MTL ::ComputeCommandEncoder *NativeEncoder = Encoder.getNative ();
1790-
1791- const auto &RTPSO =
1792- llvm::cast<MTLRayTracingPipelineState>(*IS .Pipeline .get ());
1793- const auto &SBT = llvm::cast<MTLShaderBindingTable>(*IS .SBT .get ());
1794-
1795- // Bind the global descriptor heap + top-level argument buffer the same
1796- // way the compute path does; the raygen kernel and any visible-function
1797- // callees consume them at the same slots (kIRDescriptorHeapBindPoint and
1798- // kIRArgumentBufferBindPoint).
1799- MTLGPUDescriptorHandle Handle = {};
1800- if (IS .DescHeap ) {
1801- IS .DescHeap ->bind (NativeEncoder);
1802- Handle = IS .DescHeap ->getGPUDescriptorHandleForHeapStart ();
1803- }
1804- for (uint32_t Idx = 0u ; Idx < P.Sets .size (); ++Idx) {
1805- RTPSO .ArgBuffer ->setRootDescriptorTable (Idx, Handle);
1806- Handle.addOffset (P.Sets [Idx].Resources .size ());
1807- }
1808- RTPSO .ArgBuffer ->bind (NativeEncoder);
1809-
1810- // Populate the per-dispatch IRDispatchRaysArgument: SBT region addresses
1811- // (RayGen / Miss / HitGroup / Callable), GPU pointers to the global
1812- // root-signature argument buffer + descriptor heaps, plus resource IDs
1813- // for the visible / intersection function tables. The raygen kernel
1814- // reads this struct from the buffer bound at kIRRayDispatchArgumentsBind-
1815- // Point and any visible-function callees inherit it through the same
1816- // pointer.
1817- IRDispatchRaysArgument Args{};
1818- Args.DispatchRaysDesc .RayGenerationShaderRecord = SBT .RayGenRegion ;
1819- Args.DispatchRaysDesc .MissShaderTable = SBT .MissRegion ;
1820- Args.DispatchRaysDesc .HitGroupTable = SBT .HitGroupRegion ;
1821- Args.DispatchRaysDesc .CallableShaderTable = SBT .CallableRegion ;
1822- Args.DispatchRaysDesc .Width = P.DispatchParameters .DispatchGroupCount [0 ];
1823- Args.DispatchRaysDesc .Height = P.DispatchParameters .DispatchGroupCount [1 ];
1824- Args.DispatchRaysDesc .Depth = P.DispatchParameters .DispatchGroupCount [2 ];
1825- Args.GRS = RTPSO .ArgBuffer ->getGPUAddress ();
1826- Args.ResDescHeap =
1827- IS .DescHeap ? IS .DescHeap ->getGPUDescriptorHandleForHeapStart ().Ptr : 0 ;
1828- Args.SmpDescHeap = 0 ;
1829- Args.VisibleFunctionTable =
1830- RTPSO .VFT ? RTPSO .VFT ->gpuResourceID () : MTL ::ResourceID{0 };
1831- Args.IntersectionFunctionTable =
1832- RTPSO .IFT ? RTPSO .IFT ->gpuResourceID () : MTL ::ResourceID{0 };
1833- Args.IntersectionFunctionTables = 0 ;
1834-
1835- const BufferCreateDesc ArgsBufDesc = BufferCreateDesc::uploadBuffer ();
1836- auto ArgsBufOrErr = offloadtest::createBufferWithData (
1837- *IS .CB ->Dev , " MTL Dispatch Rays Arguments" , ArgsBufDesc, &Args,
1838- sizeof (IRDispatchRaysArgument), nullptr , nullptr );
1839- if (!ArgsBufOrErr)
1840- return ArgsBufOrErr.takeError ();
1841-
1842- auto *MTLArgsBuf = llvm::cast<MTLBuffer>(ArgsBufOrErr->get ());
1843- IS .CB ->KeepAliveOwned .push_back (std::move (*ArgsBufOrErr));
1844-
1845- NativeEncoder->setBuffer (MTLArgsBuf->Buf , 0 ,
1846- kIRRayDispatchArgumentsBindPoint );
1847- NativeEncoder->useResource (MTLArgsBuf->Buf , MTL ::ResourceUsageRead);
1848-
1849- // Mark every dispatch-side resource resident: descriptor-table bundles,
1850- // acceleration structures + their irconverter header/contribution
1851- // buffers (so RayQuery/TraceRay can read them), the SBT buffer (the
1852- // raygen kernel dereferences SBT addresses), and the visible /
1853- // intersection function tables.
1854- for (const auto &Table : IS .DescTables )
1855- for (const auto &ResPair : Table.Resources )
1856- for (const auto &ResSet : ResPair.second )
1857- NativeEncoder->useResource (ResSet.Resource .get (),
1858- MTL ::ResourceUsageRead |
1859- MTL ::ResourceUsageWrite);
1860- auto MarkASResident =
1861- [&](std::unique_ptr<offloadtest::AccelerationStructure> &AS ) {
1862- auto *MTLAS = llvm::cast<MetalAccelerationStructure>(AS .get ());
1863- NativeEncoder->useResource (MTLAS ->AccelStruct ,
1864- MTL ::ResourceUsageRead);
1865- };
1866- for (auto &AS : IS .BLASes )
1867- MarkASResident (AS );
1868- for (auto &Entry : IS .TLASes )
1869- MarkASResident (Entry.second );
1870- for (auto &B : IS .ASDescriptorBuffers )
1871- NativeEncoder->useResource (llvm::cast<MTLBuffer>(B.get ())->Buf ,
1872- MTL ::ResourceUsageRead);
1873- if (SBT .Buffer )
1874- NativeEncoder->useResource (SBT .Buffer , MTL ::ResourceUsageRead);
1875- if (RTPSO .VFT )
1876- NativeEncoder->useResource (RTPSO .VFT , MTL ::ResourceUsageRead);
1877- if (RTPSO .IFT )
1878- NativeEncoder->useResource (RTPSO .IFT , MTL ::ResourceUsageRead);
1879-
1880- if (auto Err =
1881- Encoder.dispatchRays (*IS .Pipeline .get (), *IS .SBT .get (),
1882- P.DispatchParameters .DispatchGroupCount [0 ],
1883- P.DispatchParameters .DispatchGroupCount [1 ],
1884- P.DispatchParameters .DispatchGroupCount [2 ]))
1885- return Err;
1886- Encoder.endEncoding ();
1887- return llvm::Error::success ();
1888- }
1889-
18901787 llvm::Error createRenderTarget (Pipeline &P, InvocationState &IS ) {
18911788 if (!P.Bindings .RTargetBufferPtr )
18921789 return llvm::createStringError (
@@ -3094,7 +2991,9 @@ class MTLDevice : public offloadtest::Device {
30942991 IS .SBT = std::move (*SBTOrErr);
30952992 llvm::outs () << " Shader Binding Table created.\n " ;
30962993
3097- if (auto Err = createRayTracingCommands (P, IS ))
2994+ // Metal lowers raygen to a compute kernel, so ray tracing records its
2995+ // dispatch through the shared compute-command path.
2996+ if (auto Err = createComputeCommands (P, IS ))
30982997 return Err;
30992998 }
31002999
@@ -3303,6 +3202,88 @@ llvm::Error MTLComputeEncoder::batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) {
33033202
33043203 return llvm::Error::success ();
33053204}
3205+
3206+ llvm::Error MTLComputeEncoder::dispatchRays (const PipelineState &PSO ,
3207+ const ShaderBindingTable &SBT ,
3208+ uint32_t Width, uint32_t Height,
3209+ uint32_t Depth) {
3210+ if (!llvm::isa<MTLRayTracingPipelineState>(&PSO ))
3211+ return llvm::createStringError (
3212+ std::errc::invalid_argument,
3213+ " dispatchRays requires a RayTracing PipelineState." );
3214+ if (!llvm::isa<MTLShaderBindingTable>(&SBT ))
3215+ return llvm::createStringError (
3216+ std::errc::invalid_argument,
3217+ " dispatchRays requires a Metal ShaderBindingTable." );
3218+ const auto &RTPSO = llvm::cast<MTLRayTracingPipelineState>(PSO );
3219+ const auto &MTLSBT = llvm::cast<MTLShaderBindingTable>(SBT );
3220+ if (!RTPSO .ComputePipeline )
3221+ return llvm::createStringError (
3222+ std::errc::invalid_argument,
3223+ " RayTracing PipelineState has no compute pipeline state." );
3224+ if (auto Err = ensureComputeEncoder ())
3225+ return Err;
3226+ flushBarrier ();
3227+ insertDebugSignpost (
3228+ llvm::formatv (" DispatchRays [{0},{1},{2}]" , Width, Height, Depth).str ());
3229+
3230+ // Populate the per-dispatch IRDispatchRaysArgument: SBT region addresses
3231+ // (RayGen / Miss / HitGroup / Callable), GPU pointers to the global
3232+ // root-signature argument buffer + descriptor heaps, plus resource IDs for
3233+ // the visible / intersection function tables. The raygen kernel reads this
3234+ // struct from the buffer bound at kIRRayDispatchArgumentsBindPoint and any
3235+ // visible-function callees inherit it through the same pointer.
3236+ IRDispatchRaysArgument Args{};
3237+ Args.DispatchRaysDesc .RayGenerationShaderRecord = MTLSBT .RayGenRegion ;
3238+ Args.DispatchRaysDesc .MissShaderTable = MTLSBT .MissRegion ;
3239+ Args.DispatchRaysDesc .HitGroupTable = MTLSBT .HitGroupRegion ;
3240+ Args.DispatchRaysDesc .CallableShaderTable = MTLSBT .CallableRegion ;
3241+ Args.DispatchRaysDesc .Width = Width;
3242+ Args.DispatchRaysDesc .Height = Height;
3243+ Args.DispatchRaysDesc .Depth = Depth;
3244+ Args.GRS = RTPSO .ArgBuffer ->getGPUAddress ();
3245+ Args.ResDescHeap =
3246+ BoundResourceHeap
3247+ ? BoundResourceHeap->getGPUDescriptorHandleForHeapStart ().Ptr
3248+ : 0 ;
3249+ Args.SmpDescHeap = 0 ;
3250+ Args.VisibleFunctionTable =
3251+ RTPSO .VFT ? RTPSO .VFT ->gpuResourceID () : MTL ::ResourceID{0 };
3252+ Args.IntersectionFunctionTable =
3253+ RTPSO .IFT ? RTPSO .IFT ->gpuResourceID () : MTL ::ResourceID{0 };
3254+ Args.IntersectionFunctionTables = 0 ;
3255+
3256+ const BufferCreateDesc ArgsBufDesc = BufferCreateDesc::uploadBuffer ();
3257+ auto ArgsBufOrErr = offloadtest::createBufferWithData (
3258+ *CB ->Dev , " MTL Dispatch Rays Arguments" , ArgsBufDesc, &Args,
3259+ sizeof (IRDispatchRaysArgument), nullptr , nullptr );
3260+ if (!ArgsBufOrErr)
3261+ return ArgsBufOrErr.takeError ();
3262+ auto *MTLArgsBuf = llvm::cast<MTLBuffer>(ArgsBufOrErr->get ());
3263+ CB ->KeepAliveOwned .push_back (std::move (*ArgsBufOrErr));
3264+
3265+ ComputeEnc->setBuffer (MTLArgsBuf->Buf , 0 , kIRRayDispatchArgumentsBindPoint );
3266+ ComputeEnc->useResource (MTLArgsBuf->Buf , MTL ::ResourceUsageRead);
3267+
3268+ // RT-specific residency beyond what the shared compute path marks: the SBT
3269+ // buffer (the raygen kernel dereferences SBT addresses) and the visible /
3270+ // intersection function tables.
3271+ if (MTLSBT .Buffer )
3272+ ComputeEnc->useResource (MTLSBT .Buffer , MTL ::ResourceUsageRead);
3273+ if (RTPSO .VFT )
3274+ ComputeEnc->useResource (RTPSO .VFT , MTL ::ResourceUsageRead);
3275+ if (RTPSO .IFT )
3276+ ComputeEnc->useResource (RTPSO .IFT , MTL ::ResourceUsageRead);
3277+
3278+ ComputeEnc->setComputePipelineState (RTPSO .ComputePipeline );
3279+ // DispatchRays(W, H, D) launches W*H*D rays; tid in the irconverter raygen
3280+ // kernel is the per-ray index. Pass grid as raw (W, H, D) and let Metal
3281+ // ceil-divide by ThreadsPerGroup to compute threadgroup count.
3282+ const MTL ::Size GridSize (Width, Height, Depth);
3283+ ComputeEnc->dispatchThreads (GridSize, RTPSO .ThreadsPerGroup );
3284+ addBarrierScope (MTL ::BarrierScopeBuffers | MTL ::BarrierScopeTextures);
3285+ return llvm::Error::success ();
3286+ }
33063287} // namespace
33073288
33083289llvm::Error offloadtest::initializeMetalDevices (
0 commit comments