Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 98 additions & 143 deletions lib/API/MTL/MTLDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,9 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
/// the next encoder transition (via endEncodingImpl).
MTL::AccelerationStructureCommandEncoder *ASEnc = nullptr;

/// Bound resource heap, read back for IRDispatchRaysArgument::ResDescHeap.
MTLDescriptorHeap *BoundResourceHeap = nullptr;

/// Accumulated barrier scope from commands recorded since the last barrier.
MTL::BarrierScope PendingScope = MTL::BarrierScope(0);

Expand Down Expand Up @@ -665,6 +668,14 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {

MTL::ComputeCommandEncoder *getNative() const { return ComputeEnc; }

llvm::Error bindResourceHeap(MTLDescriptorHeap &Heap) {
if (auto Err = ensureComputeEncoder())
return Err;
Heap.bind(ComputeEnc);
BoundResourceHeap = &Heap;
return llvm::Error::success();
}

MTL::CommandEncoder *getActiveEncoder() const {
if (ComputeEnc)
return ComputeEnc;
Expand Down Expand Up @@ -787,40 +798,9 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
// MTL::Device handle (used to allocate scratch and instance buffers).
llvm::Error batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) override;

// Dispatch threads using a raygen compute kernel synthesized by the
// irconverter. All bindings (descriptor heap, top-level argument buffer,
// IRDispatchRaysArgument at slot 3, visible/intersection function tables,
// and the SBT buffer) must already be set on the active compute encoder by
// the caller — this method only binds the pipeline state and issues the
// dispatch.
llvm::Error dispatchRays(const PipelineState &PSO, const ShaderBindingTable &,
uint32_t Width, uint32_t Height,
uint32_t Depth) override {
if (!llvm::isa<MTLRayTracingPipelineState>(&PSO))
return llvm::createStringError(
std::errc::invalid_argument,
"dispatchRays requires a RayTracing PipelineState.");
const auto &RTPSO = llvm::cast<MTLRayTracingPipelineState>(PSO);
if (!RTPSO.ComputePipeline)
return llvm::createStringError(
std::errc::invalid_argument,
"RayTracing PipelineState has no compute pipeline state.");
if (auto Err = ensureComputeEncoder())
return Err;
flushBarrier();
insertDebugSignpost(
llvm::formatv("DispatchRays [{0},{1},{2}]", Width, Height, Depth)
.str());
ComputeEnc->setComputePipelineState(RTPSO.ComputePipeline);

// DispatchRays(W, H, D) launches W*H*D rays; tid in the irconverter raygen
// kernel is the per-ray index. Pass grid as raw (W, H, D) and let Metal
// ceil-divide by ThreadsPerGroup to compute threadgroup count.
const MTL::Size GridSize(Width, Height, Depth);
ComputeEnc->dispatchThreads(GridSize, RTPSO.ThreadsPerGroup);
addBarrierScope(MTL::BarrierScopeBuffers | MTL::BarrierScopeTextures);
return llvm::Error::success();
}
llvm::Error dispatchRays(const PipelineState &PSO,
const ShaderBindingTable &SBT, uint32_t Width,
uint32_t Height, uint32_t Depth) override;

/// Lazily transition into an AccelerationStructureCommandEncoder; mirrors
/// the existing compute↔blit lazy switch.
Comment thread
MarijnS95 marked this conversation as resolved.
Expand Down Expand Up @@ -1742,7 +1722,8 @@ class MTLDevice : public offloadtest::Device {
const auto &PS = llvm::cast<MTLPipelineState>(IS.Pipeline.get());
MTLGPUDescriptorHandle Handle = {};
if (IS.DescHeap) {
IS.DescHeap->bind(NativeEncoder);
if (auto Err = Encoder.bindResourceHeap(*IS.DescHeap))
return Err;
Handle = IS.DescHeap->getGPUDescriptorHandleForHeapStart();
}

Expand Down Expand Up @@ -1772,7 +1753,15 @@ class MTLDevice : public offloadtest::Device {
NativeEncoder->useResource(llvm::cast<MTLBuffer>(B.get())->Buf,
MTL::ResourceUsageRead);

if (auto Err = Encoder.dispatch(*IS.Pipeline.get(),
if (P.isRayTracing()) {
if (auto Err =
Encoder.dispatchRays(*IS.Pipeline.get(), *IS.SBT.get(),
P.DispatchParameters.DispatchGroupCount[0],
P.DispatchParameters.DispatchGroupCount[1],
P.DispatchParameters.DispatchGroupCount[2]))
return Err;
} else if (auto Err =
Encoder.dispatch(*IS.Pipeline.get(),
P.DispatchParameters.DispatchGroupCount[0],
P.DispatchParameters.DispatchGroupCount[1],
P.DispatchParameters.DispatchGroupCount[2]))
Expand All @@ -1781,112 +1770,6 @@ class MTLDevice : public offloadtest::Device {
return llvm::Error::success();
}

llvm::Error createRayTracingCommands(Pipeline &P, InvocationState &IS) {
auto EncoderOrErr = IS.CB->createComputeEncoder();
if (!EncoderOrErr)
return EncoderOrErr.takeError();
auto &Encoder = llvm::cast<MTLComputeEncoder>(*EncoderOrErr.get());
MTL::ComputeCommandEncoder *NativeEncoder = Encoder.getNative();

const auto &RTPSO =
llvm::cast<MTLRayTracingPipelineState>(*IS.Pipeline.get());
const auto &SBT = llvm::cast<MTLShaderBindingTable>(*IS.SBT.get());

// Bind the global descriptor heap + top-level argument buffer the same
// way the compute path does; the raygen kernel and any visible-function
// callees consume them at the same slots (kIRDescriptorHeapBindPoint and
// kIRArgumentBufferBindPoint).
MTLGPUDescriptorHandle Handle = {};
if (IS.DescHeap) {
IS.DescHeap->bind(NativeEncoder);
Handle = IS.DescHeap->getGPUDescriptorHandleForHeapStart();
}
for (uint32_t Idx = 0u; Idx < P.Sets.size(); ++Idx) {
RTPSO.ArgBuffer->setRootDescriptorTable(Idx, Handle);
Handle.addOffset(P.Sets[Idx].Resources.size());
}
RTPSO.ArgBuffer->bind(NativeEncoder);

// Populate the per-dispatch IRDispatchRaysArgument: SBT region addresses
// (RayGen / Miss / HitGroup / Callable), GPU pointers to the global
// root-signature argument buffer + descriptor heaps, plus resource IDs
// for the visible / intersection function tables. The raygen kernel
// reads this struct from the buffer bound at kIRRayDispatchArgumentsBind-
// Point and any visible-function callees inherit it through the same
// pointer.
IRDispatchRaysArgument Args{};
Args.DispatchRaysDesc.RayGenerationShaderRecord = SBT.RayGenRegion;
Args.DispatchRaysDesc.MissShaderTable = SBT.MissRegion;
Args.DispatchRaysDesc.HitGroupTable = SBT.HitGroupRegion;
Args.DispatchRaysDesc.CallableShaderTable = SBT.CallableRegion;
Args.DispatchRaysDesc.Width = P.DispatchParameters.DispatchGroupCount[0];
Args.DispatchRaysDesc.Height = P.DispatchParameters.DispatchGroupCount[1];
Args.DispatchRaysDesc.Depth = P.DispatchParameters.DispatchGroupCount[2];
Args.GRS = RTPSO.ArgBuffer->getGPUAddress();
Args.ResDescHeap =
IS.DescHeap ? IS.DescHeap->getGPUDescriptorHandleForHeapStart().Ptr : 0;
Args.SmpDescHeap = 0;
Args.VisibleFunctionTable =
RTPSO.VFT ? RTPSO.VFT->gpuResourceID() : MTL::ResourceID{0};
Args.IntersectionFunctionTable =
RTPSO.IFT ? RTPSO.IFT->gpuResourceID() : MTL::ResourceID{0};
Args.IntersectionFunctionTables = 0;

const BufferCreateDesc ArgsBufDesc = BufferCreateDesc::uploadBuffer();
auto ArgsBufOrErr = offloadtest::createBufferWithData(
*IS.CB->Dev, "MTL Dispatch Rays Arguments", ArgsBufDesc, &Args,
sizeof(IRDispatchRaysArgument), nullptr, nullptr);
if (!ArgsBufOrErr)
return ArgsBufOrErr.takeError();

auto *MTLArgsBuf = llvm::cast<MTLBuffer>(ArgsBufOrErr->get());
IS.CB->KeepAliveOwned.push_back(std::move(*ArgsBufOrErr));

NativeEncoder->setBuffer(MTLArgsBuf->Buf, 0,
kIRRayDispatchArgumentsBindPoint);
NativeEncoder->useResource(MTLArgsBuf->Buf, MTL::ResourceUsageRead);

// Mark every dispatch-side resource resident: descriptor-table bundles,
// acceleration structures + their irconverter header/contribution
// buffers (so RayQuery/TraceRay can read them), the SBT buffer (the
// raygen kernel dereferences SBT addresses), and the visible /
// intersection function tables.
for (const auto &Table : IS.DescTables)
for (const auto &ResPair : Table.Resources)
for (const auto &ResSet : ResPair.second)
NativeEncoder->useResource(ResSet.Resource.get(),
MTL::ResourceUsageRead |
MTL::ResourceUsageWrite);
auto MarkASResident =
[&](std::unique_ptr<offloadtest::AccelerationStructure> &AS) {
auto *MTLAS = llvm::cast<MetalAccelerationStructure>(AS.get());
NativeEncoder->useResource(MTLAS->AccelStruct,
MTL::ResourceUsageRead);
};
for (auto &AS : IS.BLASes)
MarkASResident(AS);
for (auto &Entry : IS.TLASes)
MarkASResident(Entry.second);
for (auto &B : IS.ASDescriptorBuffers)
NativeEncoder->useResource(llvm::cast<MTLBuffer>(B.get())->Buf,
MTL::ResourceUsageRead);
if (SBT.Buffer)
NativeEncoder->useResource(SBT.Buffer, MTL::ResourceUsageRead);
if (RTPSO.VFT)
NativeEncoder->useResource(RTPSO.VFT, MTL::ResourceUsageRead);
if (RTPSO.IFT)
NativeEncoder->useResource(RTPSO.IFT, MTL::ResourceUsageRead);

if (auto Err =
Encoder.dispatchRays(*IS.Pipeline.get(), *IS.SBT.get(),
P.DispatchParameters.DispatchGroupCount[0],
P.DispatchParameters.DispatchGroupCount[1],
P.DispatchParameters.DispatchGroupCount[2]))
return Err;
Encoder.endEncoding();
return llvm::Error::success();
}

llvm::Error createRenderTarget(Pipeline &P, InvocationState &IS) {
if (!P.Bindings.RTargetBufferPtr)
return llvm::createStringError(
Expand Down Expand Up @@ -3094,7 +2977,7 @@ class MTLDevice : public offloadtest::Device {
IS.SBT = std::move(*SBTOrErr);
llvm::outs() << "Shader Binding Table created.\n";

if (auto Err = createRayTracingCommands(P, IS))
if (auto Err = createComputeCommands(P, IS))
return Err;
}

Expand Down Expand Up @@ -3303,6 +3186,78 @@ llvm::Error MTLComputeEncoder::batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) {

return llvm::Error::success();
}

llvm::Error MTLComputeEncoder::dispatchRays(const PipelineState &PSO,
const ShaderBindingTable &SBT,
uint32_t Width, uint32_t Height,
uint32_t Depth) {
if (!llvm::isa<MTLRayTracingPipelineState>(&PSO))
return llvm::createStringError(
std::errc::invalid_argument,
"dispatchRays requires a RayTracing PipelineState.");
if (!llvm::isa<MTLShaderBindingTable>(&SBT))
return llvm::createStringError(
std::errc::invalid_argument,
"dispatchRays requires a Metal ShaderBindingTable.");
const auto &RTPSO = llvm::cast<MTLRayTracingPipelineState>(PSO);
const auto &MTLSBT = llvm::cast<MTLShaderBindingTable>(SBT);
if (!RTPSO.ComputePipeline)
return llvm::createStringError(
std::errc::invalid_argument,
"RayTracing PipelineState has no compute pipeline state.");
if (auto Err = ensureComputeEncoder())
return Err;
flushBarrier();
insertDebugSignpost(
llvm::formatv("DispatchRays [{0},{1},{2}]", Width, Height, Depth).str());

// Per-dispatch ray arguments, consumed at kIRRayDispatchArgumentsBindPoint.
IRDispatchRaysArgument Args{};
Args.DispatchRaysDesc.RayGenerationShaderRecord = MTLSBT.RayGenRegion;
Args.DispatchRaysDesc.MissShaderTable = MTLSBT.MissRegion;
Args.DispatchRaysDesc.HitGroupTable = MTLSBT.HitGroupRegion;
Args.DispatchRaysDesc.CallableShaderTable = MTLSBT.CallableRegion;
Args.DispatchRaysDesc.Width = Width;
Args.DispatchRaysDesc.Height = Height;
Args.DispatchRaysDesc.Depth = Depth;
Args.GRS = RTPSO.ArgBuffer->getGPUAddress();
Args.ResDescHeap =
BoundResourceHeap
? BoundResourceHeap->getGPUDescriptorHandleForHeapStart().Ptr
: 0;
Args.SmpDescHeap = 0;
Args.VisibleFunctionTable =
RTPSO.VFT ? RTPSO.VFT->gpuResourceID() : MTL::ResourceID{0};
Args.IntersectionFunctionTable =
RTPSO.IFT ? RTPSO.IFT->gpuResourceID() : MTL::ResourceID{0};
Args.IntersectionFunctionTables = 0;

const BufferCreateDesc ArgsBufDesc = BufferCreateDesc::uploadBuffer();
auto ArgsBufOrErr = offloadtest::createBufferWithData(
*CB->Dev, "MTL Dispatch Rays Arguments", ArgsBufDesc, &Args,
sizeof(IRDispatchRaysArgument), nullptr, nullptr);
if (!ArgsBufOrErr)
return ArgsBufOrErr.takeError();
auto *MTLArgsBuf = llvm::cast<MTLBuffer>(ArgsBufOrErr->get());
CB->KeepAliveOwned.push_back(std::move(*ArgsBufOrErr));

ComputeEnc->setBuffer(MTLArgsBuf->Buf, 0, kIRRayDispatchArgumentsBindPoint);
ComputeEnc->useResource(MTLArgsBuf->Buf, MTL::ResourceUsageRead);

if (MTLSBT.Buffer)
ComputeEnc->useResource(MTLSBT.Buffer, MTL::ResourceUsageRead);
if (RTPSO.VFT)
ComputeEnc->useResource(RTPSO.VFT, MTL::ResourceUsageRead);
if (RTPSO.IFT)
ComputeEnc->useResource(RTPSO.IFT, MTL::ResourceUsageRead);

ComputeEnc->setComputePipelineState(RTPSO.ComputePipeline);
// dispatchThreads automatically takes care of bounds checking.
const MTL::Size GridSize(Width, Height, Depth);
ComputeEnc->dispatchThreads(GridSize, RTPSO.ThreadsPerGroup);
addBarrierScope(MTL::BarrierScopeBuffers | MTL::BarrierScopeTextures);
return llvm::Error::success();
}
} // namespace

llvm::Error offloadtest::initializeMetalDevices(
Expand Down
Loading