Skip to content

Commit 176f60b

Browse files
MarijnS95 and claude committed
[Metal] Make dispatchRays a self-contained encoder operation
createRayTracingCommands was a backend-private device method that bound the descriptor heap and top-level argument buffer, synthesized and bound the IRDispatchRaysArgument, marked RT resources resident, and only then called the encoder's thin dispatchRays(). The DX and Vulkan backends instead create an encoder and call Encoder.dispatchRays() directly. Mirror that here.

Since Metal lowers raygen to a compute kernel, route ray tracing through the shared createComputeCommands path, which already binds the descriptor heap / argument buffer and marks the common resources resident, and now branches on P.isRayTracing().

The IRDispatchRaysArgument construction, upload, binding, and RT-specific residency (SBT buffer, visible/intersection function tables, the argument buffer itself) move into MTLComputeEncoder::dispatchRays, which is defined out-of-line so it can allocate via CB->Dev and keep the buffer alive in CB->KeepAliveOwned, like batchBuildAS.

IRDispatchRaysArgument::ResDescHeap needs the resource descriptor heap address, which is not in the cross-backend dispatchRays signature, so the encoder remembers the heap bound to it via bindResourceHeap() and reads its GPU address back when dispatching.

createRayTracingCommands is deleted.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 013ce93 commit 176f60b

1 file changed

Lines changed: 124 additions & 143 deletions

File tree

lib/API/MTL/MTLDevice.cpp

Lines changed: 124 additions & 143 deletions
Original file line number | Diff line number | Diff line change
@@ -609,6 +609,10 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
609609
/// the next encoder transition (via endEncodingImpl).
610610
MTL::AccelerationStructureCommandEncoder *ASEnc = nullptr;
611611

612+
/// Resource descriptor heap bound to this encoder, if any. Remembered so a
613+
/// ray dispatch can point IRDispatchRaysArgument::ResDescHeap at it.
614+
MTLDescriptorHeap *BoundResourceHeap = nullptr;
615+
612616
/// Accumulated barrier scope from commands recorded since the last barrier.
613617
MTL::BarrierScope PendingScope = MTL::BarrierScope(0);
614618

@@ -665,6 +669,16 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
665669

666670
MTL::ComputeCommandEncoder *getNative() const { return ComputeEnc; }
667671

672+
/// Bind \p Heap as the global resource descriptor heap and remember it so a
673+
/// subsequent dispatchRays can reference it from IRDispatchRaysArgument.
674+
llvm::Error bindResourceHeap(MTLDescriptorHeap &Heap) {
675+
if (auto Err = ensureComputeEncoder())
676+
return Err;
677+
Heap.bind(ComputeEnc);
678+
BoundResourceHeap = &Heap;
679+
return llvm::Error::success();
680+
}
681+
668682
MTL::CommandEncoder *getActiveEncoder() const {
669683
if (ComputeEnc)
670684
return ComputeEnc;
@@ -787,40 +801,17 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
787801
// MTL::Device handle (used to allocate scratch and instance buffers).
788802
llvm::Error batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) override;
789803

790-
// Dispatch threads using a raygen compute kernel synthesized by the
791-
// irconverter. All bindings (descriptor heap, top-level argument buffer,
792-
// IRDispatchRaysArgument at slot 3, visible/intersection function tables,
793-
// and the SBT buffer) must already be set on the active compute encoder by
794-
// the caller — this method only binds the pipeline state and issues the
795-
// dispatch.
796-
llvm::Error dispatchRays(const PipelineState &PSO, const ShaderBindingTable &,
797-
uint32_t Width, uint32_t Height,
798-
uint32_t Depth) override {
799-
if (!llvm::isa<MTLRayTracingPipelineState>(&PSO))
800-
return llvm::createStringError(
801-
std::errc::invalid_argument,
802-
"dispatchRays requires a RayTracing PipelineState.");
803-
const auto &RTPSO = llvm::cast<MTLRayTracingPipelineState>(PSO);
804-
if (!RTPSO.ComputePipeline)
805-
return llvm::createStringError(
806-
std::errc::invalid_argument,
807-
"RayTracing PipelineState has no compute pipeline state.");
808-
if (auto Err = ensureComputeEncoder())
809-
return Err;
810-
flushBarrier();
811-
insertDebugSignpost(
812-
llvm::formatv("DispatchRays [{0},{1},{2}]", Width, Height, Depth)
813-
.str());
814-
ComputeEnc->setComputePipelineState(RTPSO.ComputePipeline);
815-
816-
// DispatchRays(W, H, D) launches W*H*D rays; tid in the irconverter raygen
817-
// kernel is the per-ray index. Pass grid as raw (W, H, D) and let Metal
818-
// ceil-divide by ThreadsPerGroup to compute threadgroup count.
819-
const MTL::Size GridSize(Width, Height, Depth);
820-
ComputeEnc->dispatchThreads(GridSize, RTPSO.ThreadsPerGroup);
821-
addBarrierScope(MTL::BarrierScopeBuffers | MTL::BarrierScopeTextures);
822-
return llvm::Error::success();
823-
}
804+
// Issue a ray dispatch: synthesizes the per-dispatch IRDispatchRaysArgument
805+
// (SBT region addresses, GRS + resource-heap pointers, visible/intersection
806+
// function-table IDs), binds it at kIRRayDispatchArgumentsBindPoint, marks
807+
// the RT-specific resources resident, and dispatches the irconverter raygen
808+
// compute kernel. The caller must have already bound the descriptor heap and
809+
// top-level argument buffer (the compute-dispatch path does the same).
810+
// Defined out-of-line below — allocating the argument buffer needs
811+
// MTLDevice's full type.
812+
llvm::Error dispatchRays(const PipelineState &PSO,
813+
const ShaderBindingTable &SBT, uint32_t Width,
814+
uint32_t Height, uint32_t Depth) override;
824815

825816
/// Lazily transition into an AccelerationStructureCommandEncoder; mirrors
826817
/// the existing compute↔blit lazy switch.
@@ -1742,7 +1733,8 @@ class MTLDevice : public offloadtest::Device {
17421733
const auto &PS = llvm::cast<MTLPipelineState>(IS.Pipeline.get());
17431734
MTLGPUDescriptorHandle Handle = {};
17441735
if (IS.DescHeap) {
1745-
IS.DescHeap->bind(NativeEncoder);
1736+
if (auto Err = Encoder.bindResourceHeap(*IS.DescHeap))
1737+
return Err;
17461738
Handle = IS.DescHeap->getGPUDescriptorHandleForHeapStart();
17471739
}
17481740

@@ -1772,7 +1764,18 @@ class MTLDevice : public offloadtest::Device {
17721764
NativeEncoder->useResource(llvm::cast<MTLBuffer>(B.get())->Buf,
17731765
MTL::ResourceUsageRead);
17741766

1775-
if (auto Err = Encoder.dispatch(*IS.Pipeline.get(),
1767+
// Metal compiles raygen as a compute kernel, so a ray-tracing pipeline
1768+
// dispatches through the same compute encoder as a plain compute shader;
1769+
// dispatchRays adds the RT-specific argument buffer and residency.
1770+
if (P.isRayTracing()) {
1771+
if (auto Err =
1772+
Encoder.dispatchRays(*IS.Pipeline.get(), *IS.SBT.get(),
1773+
P.DispatchParameters.DispatchGroupCount[0],
1774+
P.DispatchParameters.DispatchGroupCount[1],
1775+
P.DispatchParameters.DispatchGroupCount[2]))
1776+
return Err;
1777+
} else if (auto Err =
1778+
Encoder.dispatch(*IS.Pipeline.get(),
17761779
P.DispatchParameters.DispatchGroupCount[0],
17771780
P.DispatchParameters.DispatchGroupCount[1],
17781781
P.DispatchParameters.DispatchGroupCount[2]))
@@ -1781,112 +1784,6 @@ class MTLDevice : public offloadtest::Device {
17811784
return llvm::Error::success();
17821785
}
17831786

1784-
llvm::Error createRayTracingCommands(Pipeline &P, InvocationState &IS) {
1785-
auto EncoderOrErr = IS.CB->createComputeEncoder();
1786-
if (!EncoderOrErr)
1787-
return EncoderOrErr.takeError();
1788-
auto &Encoder = llvm::cast<MTLComputeEncoder>(*EncoderOrErr.get());
1789-
MTL::ComputeCommandEncoder *NativeEncoder = Encoder.getNative();
1790-
1791-
const auto &RTPSO =
1792-
llvm::cast<MTLRayTracingPipelineState>(*IS.Pipeline.get());
1793-
const auto &SBT = llvm::cast<MTLShaderBindingTable>(*IS.SBT.get());
1794-
1795-
// Bind the global descriptor heap + top-level argument buffer the same
1796-
// way the compute path does; the raygen kernel and any visible-function
1797-
// callees consume them at the same slots (kIRDescriptorHeapBindPoint and
1798-
// kIRArgumentBufferBindPoint).
1799-
MTLGPUDescriptorHandle Handle = {};
1800-
if (IS.DescHeap) {
1801-
IS.DescHeap->bind(NativeEncoder);
1802-
Handle = IS.DescHeap->getGPUDescriptorHandleForHeapStart();
1803-
}
1804-
for (uint32_t Idx = 0u; Idx < P.Sets.size(); ++Idx) {
1805-
RTPSO.ArgBuffer->setRootDescriptorTable(Idx, Handle);
1806-
Handle.addOffset(P.Sets[Idx].Resources.size());
1807-
}
1808-
RTPSO.ArgBuffer->bind(NativeEncoder);
1809-
1810-
// Populate the per-dispatch IRDispatchRaysArgument: SBT region addresses
1811-
// (RayGen / Miss / HitGroup / Callable), GPU pointers to the global
1812-
// root-signature argument buffer + descriptor heaps, plus resource IDs
1813-
// for the visible / intersection function tables. The raygen kernel
1814-
// reads this struct from the buffer bound at kIRRayDispatchArgumentsBind-
1815-
// Point and any visible-function callees inherit it through the same
1816-
// pointer.
1817-
IRDispatchRaysArgument Args{};
1818-
Args.DispatchRaysDesc.RayGenerationShaderRecord = SBT.RayGenRegion;
1819-
Args.DispatchRaysDesc.MissShaderTable = SBT.MissRegion;
1820-
Args.DispatchRaysDesc.HitGroupTable = SBT.HitGroupRegion;
1821-
Args.DispatchRaysDesc.CallableShaderTable = SBT.CallableRegion;
1822-
Args.DispatchRaysDesc.Width = P.DispatchParameters.DispatchGroupCount[0];
1823-
Args.DispatchRaysDesc.Height = P.DispatchParameters.DispatchGroupCount[1];
1824-
Args.DispatchRaysDesc.Depth = P.DispatchParameters.DispatchGroupCount[2];
1825-
Args.GRS = RTPSO.ArgBuffer->getGPUAddress();
1826-
Args.ResDescHeap =
1827-
IS.DescHeap ? IS.DescHeap->getGPUDescriptorHandleForHeapStart().Ptr : 0;
1828-
Args.SmpDescHeap = 0;
1829-
Args.VisibleFunctionTable =
1830-
RTPSO.VFT ? RTPSO.VFT->gpuResourceID() : MTL::ResourceID{0};
1831-
Args.IntersectionFunctionTable =
1832-
RTPSO.IFT ? RTPSO.IFT->gpuResourceID() : MTL::ResourceID{0};
1833-
Args.IntersectionFunctionTables = 0;
1834-
1835-
const BufferCreateDesc ArgsBufDesc = BufferCreateDesc::uploadBuffer();
1836-
auto ArgsBufOrErr = offloadtest::createBufferWithData(
1837-
*IS.CB->Dev, "MTL Dispatch Rays Arguments", ArgsBufDesc, &Args,
1838-
sizeof(IRDispatchRaysArgument), nullptr, nullptr);
1839-
if (!ArgsBufOrErr)
1840-
return ArgsBufOrErr.takeError();
1841-
1842-
auto *MTLArgsBuf = llvm::cast<MTLBuffer>(ArgsBufOrErr->get());
1843-
IS.CB->KeepAliveOwned.push_back(std::move(*ArgsBufOrErr));
1844-
1845-
NativeEncoder->setBuffer(MTLArgsBuf->Buf, 0,
1846-
kIRRayDispatchArgumentsBindPoint);
1847-
NativeEncoder->useResource(MTLArgsBuf->Buf, MTL::ResourceUsageRead);
1848-
1849-
// Mark every dispatch-side resource resident: descriptor-table bundles,
1850-
// acceleration structures + their irconverter header/contribution
1851-
// buffers (so RayQuery/TraceRay can read them), the SBT buffer (the
1852-
// raygen kernel dereferences SBT addresses), and the visible /
1853-
// intersection function tables.
1854-
for (const auto &Table : IS.DescTables)
1855-
for (const auto &ResPair : Table.Resources)
1856-
for (const auto &ResSet : ResPair.second)
1857-
NativeEncoder->useResource(ResSet.Resource.get(),
1858-
MTL::ResourceUsageRead |
1859-
MTL::ResourceUsageWrite);
1860-
auto MarkASResident =
1861-
[&](std::unique_ptr<offloadtest::AccelerationStructure> &AS) {
1862-
auto *MTLAS = llvm::cast<MetalAccelerationStructure>(AS.get());
1863-
NativeEncoder->useResource(MTLAS->AccelStruct,
1864-
MTL::ResourceUsageRead);
1865-
};
1866-
for (auto &AS : IS.BLASes)
1867-
MarkASResident(AS);
1868-
for (auto &Entry : IS.TLASes)
1869-
MarkASResident(Entry.second);
1870-
for (auto &B : IS.ASDescriptorBuffers)
1871-
NativeEncoder->useResource(llvm::cast<MTLBuffer>(B.get())->Buf,
1872-
MTL::ResourceUsageRead);
1873-
if (SBT.Buffer)
1874-
NativeEncoder->useResource(SBT.Buffer, MTL::ResourceUsageRead);
1875-
if (RTPSO.VFT)
1876-
NativeEncoder->useResource(RTPSO.VFT, MTL::ResourceUsageRead);
1877-
if (RTPSO.IFT)
1878-
NativeEncoder->useResource(RTPSO.IFT, MTL::ResourceUsageRead);
1879-
1880-
if (auto Err =
1881-
Encoder.dispatchRays(*IS.Pipeline.get(), *IS.SBT.get(),
1882-
P.DispatchParameters.DispatchGroupCount[0],
1883-
P.DispatchParameters.DispatchGroupCount[1],
1884-
P.DispatchParameters.DispatchGroupCount[2]))
1885-
return Err;
1886-
Encoder.endEncoding();
1887-
return llvm::Error::success();
1888-
}
1889-
18901787
llvm::Error createRenderTarget(Pipeline &P, InvocationState &IS) {
18911788
if (!P.Bindings.RTargetBufferPtr)
18921789
return llvm::createStringError(
@@ -3094,7 +2991,9 @@ class MTLDevice : public offloadtest::Device {
30942991
IS.SBT = std::move(*SBTOrErr);
30952992
llvm::outs() << "Shader Binding Table created.\n";
30962993

3097-
if (auto Err = createRayTracingCommands(P, IS))
2994+
// Metal lowers raygen to a compute kernel, so ray tracing records its
2995+
// dispatch through the shared compute-command path.
2996+
if (auto Err = createComputeCommands(P, IS))
30982997
return Err;
30992998
}
31002999

@@ -3303,6 +3202,88 @@ llvm::Error MTLComputeEncoder::batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) {
33033202

33043203
return llvm::Error::success();
33053204
}
3205+
3206+
llvm::Error MTLComputeEncoder::dispatchRays(const PipelineState &PSO,
3207+
const ShaderBindingTable &SBT,
3208+
uint32_t Width, uint32_t Height,
3209+
uint32_t Depth) {
3210+
if (!llvm::isa<MTLRayTracingPipelineState>(&PSO))
3211+
return llvm::createStringError(
3212+
std::errc::invalid_argument,
3213+
"dispatchRays requires a RayTracing PipelineState.");
3214+
if (!llvm::isa<MTLShaderBindingTable>(&SBT))
3215+
return llvm::createStringError(
3216+
std::errc::invalid_argument,
3217+
"dispatchRays requires a Metal ShaderBindingTable.");
3218+
const auto &RTPSO = llvm::cast<MTLRayTracingPipelineState>(PSO);
3219+
const auto &MTLSBT = llvm::cast<MTLShaderBindingTable>(SBT);
3220+
if (!RTPSO.ComputePipeline)
3221+
return llvm::createStringError(
3222+
std::errc::invalid_argument,
3223+
"RayTracing PipelineState has no compute pipeline state.");
3224+
if (auto Err = ensureComputeEncoder())
3225+
return Err;
3226+
flushBarrier();
3227+
insertDebugSignpost(
3228+
llvm::formatv("DispatchRays [{0},{1},{2}]", Width, Height, Depth).str());
3229+
3230+
// Populate the per-dispatch IRDispatchRaysArgument: SBT region addresses
3231+
// (RayGen / Miss / HitGroup / Callable), GPU pointers to the global
3232+
// root-signature argument buffer + descriptor heaps, plus resource IDs for
3233+
// the visible / intersection function tables. The raygen kernel reads this
3234+
// struct from the buffer bound at kIRRayDispatchArgumentsBindPoint and any
3235+
// visible-function callees inherit it through the same pointer.
3236+
IRDispatchRaysArgument Args{};
3237+
Args.DispatchRaysDesc.RayGenerationShaderRecord = MTLSBT.RayGenRegion;
3238+
Args.DispatchRaysDesc.MissShaderTable = MTLSBT.MissRegion;
3239+
Args.DispatchRaysDesc.HitGroupTable = MTLSBT.HitGroupRegion;
3240+
Args.DispatchRaysDesc.CallableShaderTable = MTLSBT.CallableRegion;
3241+
Args.DispatchRaysDesc.Width = Width;
3242+
Args.DispatchRaysDesc.Height = Height;
3243+
Args.DispatchRaysDesc.Depth = Depth;
3244+
Args.GRS = RTPSO.ArgBuffer->getGPUAddress();
3245+
Args.ResDescHeap =
3246+
BoundResourceHeap
3247+
? BoundResourceHeap->getGPUDescriptorHandleForHeapStart().Ptr
3248+
: 0;
3249+
Args.SmpDescHeap = 0;
3250+
Args.VisibleFunctionTable =
3251+
RTPSO.VFT ? RTPSO.VFT->gpuResourceID() : MTL::ResourceID{0};
3252+
Args.IntersectionFunctionTable =
3253+
RTPSO.IFT ? RTPSO.IFT->gpuResourceID() : MTL::ResourceID{0};
3254+
Args.IntersectionFunctionTables = 0;
3255+
3256+
const BufferCreateDesc ArgsBufDesc = BufferCreateDesc::uploadBuffer();
3257+
auto ArgsBufOrErr = offloadtest::createBufferWithData(
3258+
*CB->Dev, "MTL Dispatch Rays Arguments", ArgsBufDesc, &Args,
3259+
sizeof(IRDispatchRaysArgument), nullptr, nullptr);
3260+
if (!ArgsBufOrErr)
3261+
return ArgsBufOrErr.takeError();
3262+
auto *MTLArgsBuf = llvm::cast<MTLBuffer>(ArgsBufOrErr->get());
3263+
CB->KeepAliveOwned.push_back(std::move(*ArgsBufOrErr));
3264+
3265+
ComputeEnc->setBuffer(MTLArgsBuf->Buf, 0, kIRRayDispatchArgumentsBindPoint);
3266+
ComputeEnc->useResource(MTLArgsBuf->Buf, MTL::ResourceUsageRead);
3267+
3268+
// RT-specific residency beyond what the shared compute path marks: the SBT
3269+
// buffer (the raygen kernel dereferences SBT addresses) and the visible /
3270+
// intersection function tables.
3271+
if (MTLSBT.Buffer)
3272+
ComputeEnc->useResource(MTLSBT.Buffer, MTL::ResourceUsageRead);
3273+
if (RTPSO.VFT)
3274+
ComputeEnc->useResource(RTPSO.VFT, MTL::ResourceUsageRead);
3275+
if (RTPSO.IFT)
3276+
ComputeEnc->useResource(RTPSO.IFT, MTL::ResourceUsageRead);
3277+
3278+
ComputeEnc->setComputePipelineState(RTPSO.ComputePipeline);
3279+
// DispatchRays(W, H, D) launches W*H*D rays; tid in the irconverter raygen
3280+
// kernel is the per-ray index. Pass grid as raw (W, H, D) and let Metal
3281+
// ceil-divide by ThreadsPerGroup to compute threadgroup count.
3282+
const MTL::Size GridSize(Width, Height, Depth);
3283+
ComputeEnc->dispatchThreads(GridSize, RTPSO.ThreadsPerGroup);
3284+
addBarrierScope(MTL::BarrierScopeBuffers | MTL::BarrierScopeTextures);
3285+
return llvm::Error::success();
3286+
}
33063287
} // namespace
33073288

33083289
llvm::Error offloadtest::initializeMetalDevices(

0 commit comments

Comments
 (0)