Skip to content

Commit 05afb4c

Browse files
Add Mesh Shader support to Metal. (#1218)
Add support for Mesh Shaders to Metal. Graphics/MeshShaders/SimpleTriangle.test now passes on Metal.
1 parent 1affeef commit 05afb4c

2 files changed

Lines changed: 244 additions & 43 deletions

File tree

lib/API/MTL/MTLDevice.cpp

Lines changed: 242 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,9 @@ class MTLPipelineState : public offloadtest::PipelineState {
258258
MTL::DepthStencilState *DepthStencilState = nullptr;
259259
MTL::CullMode CullMode = MTL::CullModeNone;
260260

261+
MTL::Size MeshThreadsPerThreadgroup{1, 1, 1};
262+
MTL::Size ObjectThreadsPerThreadgroup{1, 1, 1};
263+
261264
MTLPipelineState(llvm::StringRef Name, IRRootSignaturePtr RootSig,
262265
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
263266
MTL::ComputePipelineState *ComputePipeline,
@@ -270,11 +273,15 @@ class MTLPipelineState : public offloadtest::PipelineState {
270273
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
271274
MTL::RenderPipelineState *RenderPipeline,
272275
MTL::DepthStencilState *DepthStencilState,
273-
MTL::CullMode CullMode)
276+
MTL::CullMode CullMode,
277+
MTL::Size MeshThreadsPerThreadgroup = {1, 1, 1},
278+
MTL::Size ObjectThreadsPerThreadgroup = {1, 1, 1})
274279
: offloadtest::PipelineState(GPUAPI::Metal), Name(Name),
275280
RootSig(std::move(RootSig)), ArgBuffer(std::move(ArgBuffer)),
276281
RenderPipeline(RenderPipeline), DepthStencilState(DepthStencilState),
277-
CullMode(CullMode) {}
282+
CullMode(CullMode),
283+
MeshThreadsPerThreadgroup(MeshThreadsPerThreadgroup),
284+
ObjectThreadsPerThreadgroup(ObjectThreadsPerThreadgroup) {}
278285

279286
~MTLPipelineState() override {
280287
if (ComputePipeline)
@@ -714,13 +721,30 @@ class MTLRenderEncoder : public offloadtest::RenderEncoder {
714721
llvm::Error dispatchMesh(const offloadtest::PipelineState &PSO,
715722
uint32_t GroupCountX, uint32_t GroupCountY,
716723
uint32_t GroupCountZ) override {
717-
(void)PSO;
718-
(void)GroupCountX;
719-
(void)GroupCountY;
720-
(void)GroupCountZ;
724+
if (!ViewportSet)
725+
return llvm::createStringError(std::errc::invalid_argument,
726+
"Viewport must be set before drawing.");
727+
if (!ScissorSet)
728+
return llvm::createStringError(std::errc::invalid_argument,
729+
"Scissor must be set before drawing.");
721730

722-
return llvm::createStringError(
723-
"dispatchMesh is unimplemented in the Metal backend.");
731+
const auto &MTLPSO = llvm::cast<MTLPipelineState>(PSO);
732+
if (!MTLPSO.RenderPipeline)
733+
return llvm::createStringError(
734+
std::errc::invalid_argument,
735+
"PipelineState bound to dispatchMesh() is not a render pipeline.");
736+
RenderEnc->setRenderPipelineState(MTLPSO.RenderPipeline);
737+
if (MTLPSO.DepthStencilState)
738+
RenderEnc->setDepthStencilState(MTLPSO.DepthStencilState);
739+
RenderEnc->setCullMode(MTLPSO.CullMode);
740+
// Match the DX/VK convention (CCW = front) hardcoded in those backends.
741+
RenderEnc->setFrontFacingWinding(MTL::WindingCounterClockwise);
742+
743+
RenderEnc->drawMeshThreadgroups(
744+
MTL::Size(GroupCountX, GroupCountY, GroupCountZ),
745+
MTLPSO.ObjectThreadsPerThreadgroup, MTLPSO.MeshThreadsPerThreadgroup);
746+
747+
return llvm::Error::success();
724748
}
725749

726750
void endEncodingImpl() override {
@@ -1405,12 +1429,23 @@ class MTLDevice : public offloadtest::Device {
14051429
Scissor.Height = static_cast<uint32_t>(Height);
14061430
Encoder.setScissor(Scissor);
14071431

1408-
if (IS.VB)
1409-
Encoder.setVertexBuffer(0, IS.VB.get(), 0, P.Bindings.getVertexStride());
1432+
if (P.isTraditionalRaster()) {
1433+
if (IS.VB)
1434+
Encoder.setVertexBuffer(0, IS.VB.get(), 0,
1435+
P.Bindings.getVertexStride());
1436+
1437+
if (auto Err =
1438+
Encoder.drawInstanced(*IS.Pipeline.get(), P.getVertexCount(),
1439+
/*InstanceCount=*/1))
1440+
return Err;
1441+
} else {
1442+
if (auto Err = Encoder.dispatchMesh(
1443+
*IS.Pipeline.get(), P.DispatchParameters.DispatchGroupCount[0],
1444+
P.DispatchParameters.DispatchGroupCount[1],
1445+
P.DispatchParameters.DispatchGroupCount[2]))
1446+
return Err;
1447+
}
14101448

1411-
if (auto Err = Encoder.drawInstanced(*IS.Pipeline.get(), P.getVertexCount(),
1412-
/*InstanceCount=*/1))
1413-
return Err;
14141449
Encoder.endEncoding();
14151450

14161451
// Blit the render target into the readback buffer for CPU access.
@@ -1769,6 +1804,136 @@ class MTLDevice : public offloadtest::Device {
17691804
DSState, MTL::CullModeNone);
17701805
}
17711806

1807+
llvm::Expected<std::unique_ptr<PipelineState>>
1808+
createPipelineAsMsPs(llvm::StringRef Name, const BindingsDesc &BindingsDesc,
1809+
llvm::ArrayRef<Format> RTFormats,
1810+
std::optional<Format> DSFormat,
1811+
std::optional<ShaderContainer> AS, ShaderContainer MS,
1812+
std::optional<ShaderContainer> PS) {
1813+
IRRootSignaturePtr RootSig;
1814+
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer;
1815+
if (auto Err = createRootSignature(BindingsDesc, /*IsGraphics=*/true,
1816+
RootSig, ArgBuffer))
1817+
return Err;
1818+
1819+
NS::Error *Error = nullptr;
1820+
auto compileStage = [&](Stages Stage, const ShaderContainer &SC,
1821+
llvm::StringRef RoleName, MetalIR &OutIR,
1822+
MTLPtr<MTL::Library> &OutLib,
1823+
MTLPtr<MTL::Function> &OutFn) -> llvm::Error {
1824+
auto IROrErr =
1825+
convertToMetalIR(Stage, /*IsGraphics=*/true, RootSig.get(), SC);
1826+
if (!IROrErr)
1827+
return IROrErr.takeError();
1828+
OutIR = std::move(*IROrErr);
1829+
1830+
dispatch_data_t Data = IRMetalLibGetBytecodeData(OutIR.Binary.get());
1831+
NS::Error *Err = nullptr;
1832+
OutLib = MTLPtr<MTL::Library>(Device->newLibrary(Data, &Err));
1833+
if (Err)
1834+
return toError(Err);
1835+
1836+
OutFn = MTLPtr<MTL::Function>(OutLib->newFunction(
1837+
NS::String::string(SC.EntryPoint.c_str(), NS::UTF8StringEncoding)));
1838+
if (!OutFn)
1839+
return llvm::createStringError(
1840+
std::errc::invalid_argument,
1841+
"Failed to find %s entry point '%s' in Metal library.",
1842+
RoleName.data(), SC.EntryPoint.c_str());
1843+
return llvm::Error::success();
1844+
};
1845+
1846+
MetalIR MSIR;
1847+
MTLPtr<MTL::Library> MSLib;
1848+
MTLPtr<MTL::Function> MSFn;
1849+
if (auto Err = compileStage(Stages::Mesh, MS, "mesh", MSIR, MSLib, MSFn))
1850+
return Err;
1851+
1852+
MetalIR ASIR;
1853+
MTLPtr<MTL::Library> ASLib;
1854+
MTLPtr<MTL::Function> ASFn;
1855+
if (AS) {
1856+
if (auto Err = compileStage(Stages::Amplification, *AS, "amplification",
1857+
ASIR, ASLib, ASFn))
1858+
return Err;
1859+
}
1860+
1861+
MetalIR PSIR;
1862+
MTLPtr<MTL::Library> PSLib;
1863+
MTLPtr<MTL::Function> PSFn;
1864+
if (PS) {
1865+
if (auto Err =
1866+
compileStage(Stages::Pixel, *PS, "fragment", PSIR, PSLib, PSFn))
1867+
return Err;
1868+
}
1869+
1870+
MTL::MeshRenderPipelineDescriptor *Desc =
1871+
MTL::MeshRenderPipelineDescriptor::alloc()->init();
1872+
auto DescScope = llvm::scope_exit([&] { Desc->release(); });
1873+
1874+
Desc->setMeshFunction(MSFn.get());
1875+
if (ASFn)
1876+
Desc->setObjectFunction(ASFn.get());
1877+
if (PSFn)
1878+
Desc->setFragmentFunction(PSFn.get());
1879+
1880+
for (size_t I = 0; I < RTFormats.size(); ++I) {
1881+
MTL::RenderPipelineColorAttachmentDescriptor *RPCA =
1882+
MTL::RenderPipelineColorAttachmentDescriptor::alloc()->init();
1883+
RPCA->setPixelFormat(getMetalPixelFormat(RTFormats[I]));
1884+
Desc->colorAttachments()->setObject(RPCA, I);
1885+
RPCA->release();
1886+
}
1887+
1888+
if (DSFormat) {
1889+
const MTL::PixelFormat DSPixelFormat = getMetalPixelFormat(*DSFormat);
1890+
Desc->setDepthAttachmentPixelFormat(DSPixelFormat);
1891+
if (isStencilFormat(*DSFormat))
1892+
Desc->setStencilAttachmentPixelFormat(DSPixelFormat);
1893+
}
1894+
1895+
MTL::RenderPipelineState *PSO = Device->newRenderPipelineState(
1896+
Desc, MTL::PipelineOptionNone, /*reflection=*/nullptr, &Error);
1897+
if (Error)
1898+
return toError(Error);
1899+
1900+
MTL::DepthStencilDescriptor *DSDesc =
1901+
MTL::DepthStencilDescriptor::alloc()->init();
1902+
DSDesc->setDepthCompareFunction(MTL::CompareFunctionLess);
1903+
DSDesc->setDepthWriteEnabled(true);
1904+
MTL::DepthStencilState *DSState = Device->newDepthStencilState(DSDesc);
1905+
DSDesc->release();
1906+
1907+
// Pull threads-per-threadgroup from shader reflection.
1908+
MTL::Size MeshTGSize(1, 1, 1);
1909+
{
1910+
IRVersionedMSInfo MSInfo;
1911+
if (IRShaderReflectionCopyMeshInfo(MSIR.Reflection.get(),
1912+
IRReflectionVersion_1_0, &MSInfo)) {
1913+
MeshTGSize = MTL::Size(MSInfo.info_1_0.num_threads[0],
1914+
MSInfo.info_1_0.num_threads[1],
1915+
MSInfo.info_1_0.num_threads[2]);
1916+
}
1917+
IRShaderReflectionReleaseMeshInfo(&MSInfo);
1918+
}
1919+
1920+
MTL::Size ObjectTGSize(1, 1, 1);
1921+
if (AS) {
1922+
IRVersionedASInfo ASInfo;
1923+
if (IRShaderReflectionCopyAmplificationInfo(
1924+
ASIR.Reflection.get(), IRReflectionVersion_1_0, &ASInfo)) {
1925+
ObjectTGSize = MTL::Size(ASInfo.info_1_0.num_threads[0],
1926+
ASInfo.info_1_0.num_threads[1],
1927+
ASInfo.info_1_0.num_threads[2]);
1928+
}
1929+
IRShaderReflectionReleaseAmplificationInfo(&ASInfo);
1930+
}
1931+
1932+
return std::make_unique<MTLPipelineState>(
1933+
Name, std::move(RootSig), std::move(ArgBuffer), PSO, DSState,
1934+
MTL::CullModeNone, MeshTGSize, ObjectTGSize);
1935+
}
1936+
17721937
llvm::Error executeProgram(Pipeline &P) override {
17731938
InvocationState IS;
17741939

@@ -1817,40 +1982,68 @@ class MTLDevice : public offloadtest::Device {
18171982

18181983
if (auto Err = createComputeCommands(P, IS))
18191984
return Err;
1820-
} else {
1821-
TraditionalRasterPipelineCreateDesc PipelineDesc = {};
1822-
PipelineDesc.Topology = P.Bindings.Topology;
1823-
PipelineDesc.DSFormat = Format::D32FloatS8Uint;
1824-
for (auto &Shader : P.Shaders) {
1825-
ShaderContainer SC = {};
1826-
SC.EntryPoint = Shader.Entry;
1827-
SC.Shader = Shader.Shader.get();
1828-
PipelineDesc.setShader(Shader.Stage, std::move(SC));
1829-
}
1830-
1831-
for (auto &Attr : P.Bindings.VertexAttributes) {
1832-
auto FormatOrErr = toFormat(Attr.Format, Attr.Channels);
1833-
if (!FormatOrErr)
1834-
return FormatOrErr.takeError();
1835-
1836-
InputLayoutDesc Layout = {};
1837-
Layout.Name = Attr.Name;
1838-
Layout.Fmt = *FormatOrErr;
1839-
Layout.OffsetInBytes = Attr.Offset;
1840-
PipelineDesc.InputLayout.push_back(Layout);
1841-
}
1842-
1985+
} else if (P.isRaster()) {
18431986
auto FormatOrErr = toFormat(P.Bindings.RTargetBufferPtr->Format,
18441987
P.Bindings.RTargetBufferPtr->Channels);
1988+
1989+
llvm::SmallVector<Format> RTFormats;
18451990
if (!FormatOrErr)
18461991
return FormatOrErr.takeError();
1847-
PipelineDesc.RTFormats.push_back(*FormatOrErr);
1992+
RTFormats.push_back(*FormatOrErr);
1993+
1994+
if (P.isTraditionalRaster()) {
1995+
TraditionalRasterPipelineCreateDesc PipelineDesc = {};
1996+
PipelineDesc.Topology = P.Bindings.Topology;
1997+
PipelineDesc.DSFormat = Format::D32FloatS8Uint;
1998+
PipelineDesc.RTFormats = RTFormats;
1999+
for (auto &Shader : P.Shaders) {
2000+
ShaderContainer SC = {};
2001+
SC.EntryPoint = Shader.Entry;
2002+
SC.Shader = Shader.Shader.get();
2003+
PipelineDesc.setShader(Shader.Stage, std::move(SC));
2004+
}
18482005

1849-
auto PipelineStateOrErr = createTraditionalRasterPipeline(
1850-
"Graphics Pipeline State", Bindings, PipelineDesc);
1851-
if (!PipelineStateOrErr)
1852-
return PipelineStateOrErr.takeError();
1853-
IS.Pipeline = std::move(*PipelineStateOrErr);
2006+
for (auto &Attr : P.Bindings.VertexAttributes) {
2007+
auto FormatOrErr = toFormat(Attr.Format, Attr.Channels);
2008+
if (!FormatOrErr)
2009+
return FormatOrErr.takeError();
2010+
2011+
InputLayoutDesc Layout = {};
2012+
Layout.Name = Attr.Name;
2013+
Layout.Fmt = *FormatOrErr;
2014+
Layout.OffsetInBytes = Attr.Offset;
2015+
PipelineDesc.InputLayout.push_back(Layout);
2016+
}
2017+
2018+
auto PipelineStateOrErr = createTraditionalRasterPipeline(
2019+
"Graphics Pipeline State", Bindings, PipelineDesc);
2020+
if (!PipelineStateOrErr)
2021+
return PipelineStateOrErr.takeError();
2022+
IS.Pipeline = std::move(*PipelineStateOrErr);
2023+
} else if (P.isMeshShaderRaster()) {
2024+
std::optional<ShaderContainer> AS;
2025+
ShaderContainer MS = {};
2026+
std::optional<ShaderContainer> PS;
2027+
for (auto &Shader : P.Shaders) {
2028+
ShaderContainer SC = {};
2029+
SC.EntryPoint = Shader.Entry;
2030+
SC.Shader = Shader.Shader.get();
2031+
if (Shader.Stage == Stages::Amplification)
2032+
AS = std::move(SC);
2033+
else if (Shader.Stage == Stages::Mesh)
2034+
MS = std::move(SC);
2035+
else if (Shader.Stage == Stages::Pixel)
2036+
PS = std::move(SC);
2037+
}
2038+
2039+
auto PipelineStateOrErr =
2040+
createPipelineAsMsPs("Mesh Shader Pipeline State", Bindings,
2041+
RTFormats, Format::D32FloatS8Uint, AS, MS, PS);
2042+
if (!PipelineStateOrErr)
2043+
return PipelineStateOrErr.takeError();
2044+
IS.Pipeline = std::move(*PipelineStateOrErr);
2045+
llvm::outs() << "Mesh Shader Pipeline created.\n";
2046+
}
18542047

18552048
ColorAttachmentFormatDesc ColorAttachment = {};
18562049
ColorAttachment.Fmt = *FormatOrErr;
@@ -1892,7 +2085,13 @@ class MTLDevice : public offloadtest::Device {
18922085
virtual ~MTLDevice() {};
18932086

18942087
private:
1895-
void queryCapabilities() {}
2088+
void queryCapabilities() {
2089+
// GPU Family Metal3 (macOS 13+) is where mesh shaders became available.
2090+
const bool MeshShaderSupported =
2091+
Device->supportsFamily(MTL::GPUFamilyMetal3);
2092+
Caps.insert(std::make_pair(
2093+
"MeshShader", makeCapability<bool>("MeshShader", MeshShaderSupported)));
2094+
}
18962095
};
18972096
} // namespace
18982097

test/lit.cfg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ def setDeviceFeatures(config, device, compiler):
158158
config.available_features.add("Int16")
159159
config.available_features.add("Int64")
160160
config.available_features.add("Half")
161+
if device["Features"].get("MeshShader", False):
162+
config.available_features.add("MeshShader")
161163

162164
if device["API"] == "Vulkan":
163165
if device["Features"].get("shaderInt16", False):

0 commit comments

Comments
 (0)