Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 242 additions & 43 deletions lib/API/MTL/MTLDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,9 @@ class MTLPipelineState : public offloadtest::PipelineState {
MTL::DepthStencilState *DepthStencilState = nullptr;
MTL::CullMode CullMode = MTL::CullModeNone;

MTL::Size MeshThreadsPerThreadgroup{1, 1, 1};
MTL::Size ObjectThreadsPerThreadgroup{1, 1, 1};

MTLPipelineState(llvm::StringRef Name, IRRootSignaturePtr RootSig,
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
MTL::ComputePipelineState *ComputePipeline,
Expand All @@ -266,11 +269,15 @@ class MTLPipelineState : public offloadtest::PipelineState {
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
MTL::RenderPipelineState *RenderPipeline,
MTL::DepthStencilState *DepthStencilState,
MTL::CullMode CullMode)
MTL::CullMode CullMode,
MTL::Size MeshThreadsPerThreadgroup = {1, 1, 1},
MTL::Size ObjectThreadsPerThreadgroup = {1, 1, 1})
: offloadtest::PipelineState(GPUAPI::Metal), Name(Name),
RootSig(std::move(RootSig)), ArgBuffer(std::move(ArgBuffer)),
RenderPipeline(RenderPipeline), DepthStencilState(DepthStencilState),
CullMode(CullMode) {}
CullMode(CullMode),
MeshThreadsPerThreadgroup(MeshThreadsPerThreadgroup),
ObjectThreadsPerThreadgroup(ObjectThreadsPerThreadgroup) {}

~MTLPipelineState() override {
if (ComputePipeline)
Expand Down Expand Up @@ -710,13 +717,30 @@ class MTLRenderEncoder : public offloadtest::RenderEncoder {
llvm::Error dispatchMesh(const offloadtest::PipelineState &PSO,
uint32_t GroupCountX, uint32_t GroupCountY,
uint32_t GroupCountZ) override {
(void)PSO;
(void)GroupCountX;
(void)GroupCountY;
(void)GroupCountZ;
if (!ViewportSet)
return llvm::createStringError(std::errc::invalid_argument,
"Viewport must be set before drawing.");
if (!ScissorSet)
return llvm::createStringError(std::errc::invalid_argument,
"Scissor must be set before drawing.");

return llvm::createStringError(
"dispatchMesh is unimplemented in the Metal backend.");
const auto &MTLPSO = llvm::cast<MTLPipelineState>(PSO);
if (!MTLPSO.RenderPipeline)
return llvm::createStringError(
std::errc::invalid_argument,
"PipelineState bound to dispatchMesh() is not a render pipeline.");
RenderEnc->setRenderPipelineState(MTLPSO.RenderPipeline);
if (MTLPSO.DepthStencilState)
RenderEnc->setDepthStencilState(MTLPSO.DepthStencilState);
RenderEnc->setCullMode(MTLPSO.CullMode);
// Match the DX/VK convention (CCW = front) hardcoded in those backends.
RenderEnc->setFrontFacingWinding(MTL::WindingCounterClockwise);

RenderEnc->drawMeshThreadgroups(
MTL::Size(GroupCountX, GroupCountY, GroupCountZ),
MTLPSO.ObjectThreadsPerThreadgroup, MTLPSO.MeshThreadsPerThreadgroup);

return llvm::Error::success();
}

void endEncodingImpl() override {
Expand Down Expand Up @@ -1401,12 +1425,23 @@ class MTLDevice : public offloadtest::Device {
Scissor.Height = static_cast<uint32_t>(Height);
Encoder.setScissor(Scissor);

if (IS.VB)
Encoder.setVertexBuffer(0, IS.VB.get(), 0, P.Bindings.getVertexStride());
if (P.isTraditionalRaster()) {
if (IS.VB)
Encoder.setVertexBuffer(0, IS.VB.get(), 0,
P.Bindings.getVertexStride());

if (auto Err =
Encoder.drawInstanced(*IS.Pipeline.get(), P.getVertexCount(),
/*InstanceCount=*/1))
return Err;
} else {
if (auto Err = Encoder.dispatchMesh(
*IS.Pipeline.get(), P.DispatchParameters.DispatchGroupCount[0],
P.DispatchParameters.DispatchGroupCount[1],
P.DispatchParameters.DispatchGroupCount[2]))
return Err;
}

if (auto Err = Encoder.drawInstanced(*IS.Pipeline.get(), P.getVertexCount(),
/*InstanceCount=*/1))
return Err;
Encoder.endEncoding();

// Blit the render target into the readback buffer for CPU access.
Expand Down Expand Up @@ -1760,6 +1795,136 @@ class MTLDevice : public offloadtest::Device {
DSState, MTL::CullModeNone);
}

llvm::Expected<std::unique_ptr<PipelineState>>
createPipelineAsMsPs(llvm::StringRef Name, const BindingsDesc &BindingsDesc,
llvm::ArrayRef<Format> RTFormats,
std::optional<Format> DSFormat,
std::optional<ShaderContainer> AS, ShaderContainer MS,
std::optional<ShaderContainer> PS) {
IRRootSignaturePtr RootSig;
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer;
if (auto Err = createRootSignature(BindingsDesc, /*IsGraphics=*/true,
RootSig, ArgBuffer))
return Err;

NS::Error *Error = nullptr;
auto compileStage = [&](Stages Stage, const ShaderContainer &SC,
llvm::StringRef RoleName, MetalIR &OutIR,
MTLPtr<MTL::Library> &OutLib,
MTLPtr<MTL::Function> &OutFn) -> llvm::Error {
auto IROrErr =
convertToMetalIR(Stage, /*IsGraphics=*/true, RootSig.get(), SC);
if (!IROrErr)
return IROrErr.takeError();
OutIR = std::move(*IROrErr);

dispatch_data_t Data = IRMetalLibGetBytecodeData(OutIR.Binary.get());
NS::Error *Err = nullptr;
OutLib = MTLPtr<MTL::Library>(Device->newLibrary(Data, &Err));
if (Err)
return toError(Err);

OutFn = MTLPtr<MTL::Function>(OutLib->newFunction(
NS::String::string(SC.EntryPoint.c_str(), NS::UTF8StringEncoding)));
if (!OutFn)
return llvm::createStringError(
std::errc::invalid_argument,
"Failed to find %s entry point '%s' in Metal library.",
RoleName.data(), SC.EntryPoint.c_str());
return llvm::Error::success();
};

MetalIR MSIR;
MTLPtr<MTL::Library> MSLib;
MTLPtr<MTL::Function> MSFn;
if (auto Err = compileStage(Stages::Mesh, MS, "mesh", MSIR, MSLib, MSFn))
return Err;

MetalIR ASIR;
MTLPtr<MTL::Library> ASLib;
MTLPtr<MTL::Function> ASFn;
if (AS) {
if (auto Err = compileStage(Stages::Amplification, *AS, "amplification",
ASIR, ASLib, ASFn))
return Err;
}

MetalIR PSIR;
MTLPtr<MTL::Library> PSLib;
MTLPtr<MTL::Function> PSFn;
if (PS) {
if (auto Err =
compileStage(Stages::Pixel, *PS, "fragment", PSIR, PSLib, PSFn))
return Err;
}

MTL::MeshRenderPipelineDescriptor *Desc =
MTL::MeshRenderPipelineDescriptor::alloc()->init();
auto DescScope = llvm::scope_exit([&] { Desc->release(); });

Desc->setMeshFunction(MSFn.get());
if (ASFn)
Desc->setObjectFunction(ASFn.get());
if (PSFn)
Desc->setFragmentFunction(PSFn.get());

for (size_t I = 0; I < RTFormats.size(); ++I) {
MTL::RenderPipelineColorAttachmentDescriptor *RPCA =
MTL::RenderPipelineColorAttachmentDescriptor::alloc()->init();
RPCA->setPixelFormat(getMetalPixelFormat(RTFormats[I]));
Desc->colorAttachments()->setObject(RPCA, I);
RPCA->release();
}

if (DSFormat) {
const MTL::PixelFormat DSPixelFormat = getMetalPixelFormat(*DSFormat);
Desc->setDepthAttachmentPixelFormat(DSPixelFormat);
if (isStencilFormat(*DSFormat))
Desc->setStencilAttachmentPixelFormat(DSPixelFormat);
}

MTL::RenderPipelineState *PSO = Device->newRenderPipelineState(
Desc, MTL::PipelineOptionNone, /*reflection=*/nullptr, &Error);
if (Error)
return toError(Error);

MTL::DepthStencilDescriptor *DSDesc =
MTL::DepthStencilDescriptor::alloc()->init();
DSDesc->setDepthCompareFunction(MTL::CompareFunctionLess);
DSDesc->setDepthWriteEnabled(true);
MTL::DepthStencilState *DSState = Device->newDepthStencilState(DSDesc);
DSDesc->release();

// Pull threads-per-threadgroup from shader reflection.
MTL::Size MeshTGSize(1, 1, 1);
{
IRVersionedMSInfo MSInfo;
if (IRShaderReflectionCopyMeshInfo(MSIR.Reflection.get(),
IRReflectionVersion_1_0, &MSInfo)) {
MeshTGSize = MTL::Size(MSInfo.info_1_0.num_threads[0],
MSInfo.info_1_0.num_threads[1],
MSInfo.info_1_0.num_threads[2]);
}
IRShaderReflectionReleaseMeshInfo(&MSInfo);
}

MTL::Size ObjectTGSize(1, 1, 1);
if (AS) {
IRVersionedASInfo ASInfo;
if (IRShaderReflectionCopyAmplificationInfo(
ASIR.Reflection.get(), IRReflectionVersion_1_0, &ASInfo)) {
ObjectTGSize = MTL::Size(ASInfo.info_1_0.num_threads[0],
ASInfo.info_1_0.num_threads[1],
ASInfo.info_1_0.num_threads[2]);
}
IRShaderReflectionReleaseAmplificationInfo(&ASInfo);
}

return std::make_unique<MTLPipelineState>(
Name, std::move(RootSig), std::move(ArgBuffer), PSO, DSState,
MTL::CullModeNone, MeshTGSize, ObjectTGSize);
}

llvm::Error executeProgram(Pipeline &P) override {
InvocationState IS;

Expand Down Expand Up @@ -1808,40 +1973,68 @@ class MTLDevice : public offloadtest::Device {

if (auto Err = createComputeCommands(P, IS))
return Err;
} else {
TraditionalRasterPipelineCreateDesc PipelineDesc = {};
PipelineDesc.Topology = P.Bindings.Topology;
PipelineDesc.DSFormat = Format::D32FloatS8Uint;
for (auto &Shader : P.Shaders) {
ShaderContainer SC = {};
SC.EntryPoint = Shader.Entry;
SC.Shader = Shader.Shader.get();
PipelineDesc.setShader(Shader.Stage, std::move(SC));
}

for (auto &Attr : P.Bindings.VertexAttributes) {
auto FormatOrErr = toFormat(Attr.Format, Attr.Channels);
if (!FormatOrErr)
return FormatOrErr.takeError();

InputLayoutDesc Layout = {};
Layout.Name = Attr.Name;
Layout.Fmt = *FormatOrErr;
Layout.OffsetInBytes = Attr.Offset;
PipelineDesc.InputLayout.push_back(Layout);
}

} else if (P.isRaster()) {
auto FormatOrErr = toFormat(P.Bindings.RTargetBufferPtr->Format,
P.Bindings.RTargetBufferPtr->Channels);

llvm::SmallVector<Format> RTFormats;
if (!FormatOrErr)
return FormatOrErr.takeError();
PipelineDesc.RTFormats.push_back(*FormatOrErr);
RTFormats.push_back(*FormatOrErr);

if (P.isTraditionalRaster()) {
TraditionalRasterPipelineCreateDesc PipelineDesc = {};
PipelineDesc.Topology = P.Bindings.Topology;
PipelineDesc.DSFormat = Format::D32FloatS8Uint;
PipelineDesc.RTFormats = RTFormats;
for (auto &Shader : P.Shaders) {
ShaderContainer SC = {};
SC.EntryPoint = Shader.Entry;
SC.Shader = Shader.Shader.get();
PipelineDesc.setShader(Shader.Stage, std::move(SC));
}

auto PipelineStateOrErr = createTraditionalRasterPipeline(
"Graphics Pipeline State", Bindings, PipelineDesc);
if (!PipelineStateOrErr)
return PipelineStateOrErr.takeError();
IS.Pipeline = std::move(*PipelineStateOrErr);
for (auto &Attr : P.Bindings.VertexAttributes) {
auto FormatOrErr = toFormat(Attr.Format, Attr.Channels);
if (!FormatOrErr)
return FormatOrErr.takeError();

InputLayoutDesc Layout = {};
Layout.Name = Attr.Name;
Layout.Fmt = *FormatOrErr;
Layout.OffsetInBytes = Attr.Offset;
PipelineDesc.InputLayout.push_back(Layout);
}

auto PipelineStateOrErr = createTraditionalRasterPipeline(
"Graphics Pipeline State", Bindings, PipelineDesc);
if (!PipelineStateOrErr)
return PipelineStateOrErr.takeError();
IS.Pipeline = std::move(*PipelineStateOrErr);
} else if (P.isMeshShaderRaster()) {
std::optional<ShaderContainer> AS;
ShaderContainer MS = {};
std::optional<ShaderContainer> PS;
for (auto &Shader : P.Shaders) {
ShaderContainer SC = {};
SC.EntryPoint = Shader.Entry;
SC.Shader = Shader.Shader.get();
if (Shader.Stage == Stages::Amplification)
AS = std::move(SC);
else if (Shader.Stage == Stages::Mesh)
MS = std::move(SC);
else if (Shader.Stage == Stages::Pixel)
PS = std::move(SC);
}

auto PipelineStateOrErr =
createPipelineAsMsPs("Mesh Shader Pipeline State", Bindings,
RTFormats, Format::D32FloatS8Uint, AS, MS, PS);
if (!PipelineStateOrErr)
return PipelineStateOrErr.takeError();
IS.Pipeline = std::move(*PipelineStateOrErr);
llvm::outs() << "Mesh Shader Pipeline created.\n";
}

ColorAttachmentFormatDesc ColorAttachment = {};
ColorAttachment.Fmt = *FormatOrErr;
Expand Down Expand Up @@ -1883,7 +2076,13 @@ class MTLDevice : public offloadtest::Device {
virtual ~MTLDevice() {};

private:
void queryCapabilities() {}
void queryCapabilities() {
// GPU Family Metal3 (macOS 13+) is where mesh shaders became available.
const bool MeshShaderSupported =
Device->supportsFamily(MTL::GPUFamilyMetal3);
Caps.insert(std::make_pair(
"MeshShader", makeCapability<bool>("MeshShader", MeshShaderSupported)));
}
};
} // namespace

Expand Down
2 changes: 2 additions & 0 deletions test/lit.cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ def setDeviceFeatures(config, device, compiler):
config.available_features.add("Int16")
config.available_features.add("Int64")
config.available_features.add("Half")
if device["Features"].get("MeshShader", False):
config.available_features.add("MeshShader")

if device["API"] == "Vulkan":
if device["Features"].get("shaderInt16", False):
Expand Down
Loading