Skip to content

Commit f1bc124

Browse files
Add Mesh Shader support to Metal.
1 parent 997b1c7 commit f1bc124

2 files changed

Lines changed: 243 additions & 42 deletions

File tree

lib/API/MTL/MTLDevice.cpp

Lines changed: 241 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,9 @@ class MTLPipelineState : public offloadtest::PipelineState {
254254
MTL::DepthStencilState *DepthStencilState = nullptr;
255255
MTL::CullMode CullMode = MTL::CullModeNone;
256256

257+
MTL::Size MeshThreadsPerThreadgroup{1, 1, 1};
258+
MTL::Size ObjectThreadsPerThreadgroup{1, 1, 1};
259+
257260
MTLPipelineState(llvm::StringRef Name, IRRootSignaturePtr RootSig,
258261
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
259262
MTL::ComputePipelineState *ComputePipeline,
@@ -266,11 +269,15 @@ class MTLPipelineState : public offloadtest::PipelineState {
266269
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
267270
MTL::RenderPipelineState *RenderPipeline,
268271
MTL::DepthStencilState *DepthStencilState,
269-
MTL::CullMode CullMode)
272+
MTL::CullMode CullMode,
273+
MTL::Size MeshThreadsPerThreadgroup = {1, 1, 1},
274+
MTL::Size ObjectThreadsPerThreadgroup = {1, 1, 1})
270275
: offloadtest::PipelineState(GPUAPI::Metal), Name(Name),
271276
RootSig(std::move(RootSig)), ArgBuffer(std::move(ArgBuffer)),
272277
RenderPipeline(RenderPipeline), DepthStencilState(DepthStencilState),
273-
CullMode(CullMode) {}
278+
CullMode(CullMode),
279+
MeshThreadsPerThreadgroup(MeshThreadsPerThreadgroup),
280+
ObjectThreadsPerThreadgroup(ObjectThreadsPerThreadgroup) {}
274281

275282
~MTLPipelineState() override {
276283
if (ComputePipeline)
@@ -710,13 +717,30 @@ class MTLRenderEncoder : public offloadtest::RenderEncoder {
710717
llvm::Error dispatchMesh(const offloadtest::PipelineState &PSO,
711718
uint32_t GroupCountX, uint32_t GroupCountY,
712719
uint32_t GroupCountZ) override {
713-
(void)PSO;
714-
(void)GroupCountX;
715-
(void)GroupCountY;
716-
(void)GroupCountZ;
720+
if (!ViewportSet)
721+
return llvm::createStringError(std::errc::invalid_argument,
722+
"Viewport must be set before drawing.");
723+
if (!ScissorSet)
724+
return llvm::createStringError(std::errc::invalid_argument,
725+
"Scissor must be set before drawing.");
717726

718-
return llvm::createStringError(
719-
"dispatchMesh is unimplemented in the Metal backend.");
727+
const auto &MTLPSO = llvm::cast<MTLPipelineState>(PSO);
728+
if (!MTLPSO.RenderPipeline)
729+
return llvm::createStringError(
730+
std::errc::invalid_argument,
731+
"PipelineState bound to dispatchMesh() is not a render pipeline.");
732+
RenderEnc->setRenderPipelineState(MTLPSO.RenderPipeline);
733+
if (MTLPSO.DepthStencilState)
734+
RenderEnc->setDepthStencilState(MTLPSO.DepthStencilState);
735+
RenderEnc->setCullMode(MTLPSO.CullMode);
736+
// Match the DX/VK convention (CCW = front) hardcoded in those backends.
737+
RenderEnc->setFrontFacingWinding(MTL::WindingCounterClockwise);
738+
739+
RenderEnc->drawMeshThreadgroups(
740+
MTL::Size(GroupCountX, GroupCountY, GroupCountZ),
741+
MTLPSO.ObjectThreadsPerThreadgroup, MTLPSO.MeshThreadsPerThreadgroup);
742+
743+
return llvm::Error::success();
720744
}
721745

722746
void endEncodingImpl() override {
@@ -1401,12 +1425,23 @@ class MTLDevice : public offloadtest::Device {
14011425
Scissor.Height = static_cast<uint32_t>(Height);
14021426
Encoder.setScissor(Scissor);
14031427

1404-
if (IS.VB)
1405-
Encoder.setVertexBuffer(0, IS.VB.get(), 0, P.Bindings.getVertexStride());
1428+
if (P.isTraditionalRaster()) {
1429+
if (IS.VB)
1430+
Encoder.setVertexBuffer(0, IS.VB.get(), 0,
1431+
P.Bindings.getVertexStride());
1432+
1433+
if (auto Err =
1434+
Encoder.drawInstanced(*IS.Pipeline.get(), P.getVertexCount(),
1435+
/*InstanceCount=*/1))
1436+
return Err;
1437+
} else {
1438+
if (auto Err = Encoder.dispatchMesh(
1439+
*IS.Pipeline.get(), P.DispatchParameters.DispatchGroupCount[0],
1440+
P.DispatchParameters.DispatchGroupCount[1],
1441+
P.DispatchParameters.DispatchGroupCount[2]))
1442+
return Err;
1443+
}
14061444

1407-
if (auto Err = Encoder.drawInstanced(*IS.Pipeline.get(), P.getVertexCount(),
1408-
/*InstanceCount=*/1))
1409-
return Err;
14101445
Encoder.endEncoding();
14111446

14121447
// Blit the render target into the readback buffer for CPU access.
@@ -1760,6 +1795,136 @@ class MTLDevice : public offloadtest::Device {
17601795
DSState, MTL::CullModeNone);
17611796
}
17621797

1798+
llvm::Expected<std::unique_ptr<PipelineState>>
1799+
createPipelineAsMsPs(llvm::StringRef Name, const BindingsDesc &BindingsDesc,
1800+
llvm::ArrayRef<Format> RTFormats,
1801+
std::optional<Format> DSFormat,
1802+
std::optional<ShaderContainer> AS, ShaderContainer MS,
1803+
std::optional<ShaderContainer> PS) {
1804+
IRRootSignaturePtr RootSig;
1805+
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer;
1806+
if (auto Err = createRootSignature(BindingsDesc, /*IsGraphics=*/true,
1807+
RootSig, ArgBuffer))
1808+
return Err;
1809+
1810+
NS::Error *Error = nullptr;
1811+
auto compileStage = [&](Stages Stage, const ShaderContainer &SC,
1812+
llvm::StringRef RoleName, MetalIR &OutIR,
1813+
MTLPtr<MTL::Library> &OutLib,
1814+
MTLPtr<MTL::Function> &OutFn) -> llvm::Error {
1815+
auto IROrErr =
1816+
convertToMetalIR(Stage, /*IsGraphics=*/true, RootSig.get(), SC);
1817+
if (!IROrErr)
1818+
return IROrErr.takeError();
1819+
OutIR = std::move(*IROrErr);
1820+
1821+
dispatch_data_t Data = IRMetalLibGetBytecodeData(OutIR.Binary.get());
1822+
NS::Error *Err = nullptr;
1823+
OutLib = MTLPtr<MTL::Library>(Device->newLibrary(Data, &Err));
1824+
if (Err)
1825+
return toError(Err);
1826+
1827+
OutFn = MTLPtr<MTL::Function>(OutLib->newFunction(
1828+
NS::String::string(SC.EntryPoint.c_str(), NS::UTF8StringEncoding)));
1829+
if (!OutFn)
1830+
return llvm::createStringError(
1831+
std::errc::invalid_argument,
1832+
"Failed to find %s entry point '%s' in Metal library.",
1833+
RoleName.data(), SC.EntryPoint.c_str());
1834+
return llvm::Error::success();
1835+
};
1836+
1837+
MetalIR MSIR;
1838+
MTLPtr<MTL::Library> MSLib;
1839+
MTLPtr<MTL::Function> MSFn;
1840+
if (auto Err = compileStage(Stages::Mesh, MS, "mesh", MSIR, MSLib, MSFn))
1841+
return Err;
1842+
1843+
MetalIR ASIR;
1844+
MTLPtr<MTL::Library> ASLib;
1845+
MTLPtr<MTL::Function> ASFn;
1846+
if (AS) {
1847+
if (auto Err = compileStage(Stages::Amplification, *AS, "amplification",
1848+
ASIR, ASLib, ASFn))
1849+
return Err;
1850+
}
1851+
1852+
MetalIR PSIR;
1853+
MTLPtr<MTL::Library> PSLib;
1854+
MTLPtr<MTL::Function> PSFn;
1855+
if (PS) {
1856+
if (auto Err =
1857+
compileStage(Stages::Pixel, *PS, "fragment", PSIR, PSLib, PSFn))
1858+
return Err;
1859+
}
1860+
1861+
MTL::MeshRenderPipelineDescriptor *Desc =
1862+
MTL::MeshRenderPipelineDescriptor::alloc()->init();
1863+
auto DescScope = llvm::scope_exit([&] { Desc->release(); });
1864+
1865+
Desc->setMeshFunction(MSFn.get());
1866+
if (ASFn)
1867+
Desc->setObjectFunction(ASFn.get());
1868+
if (PSFn)
1869+
Desc->setFragmentFunction(PSFn.get());
1870+
1871+
for (size_t I = 0; I < RTFormats.size(); ++I) {
1872+
MTL::RenderPipelineColorAttachmentDescriptor *RPCA =
1873+
MTL::RenderPipelineColorAttachmentDescriptor::alloc()->init();
1874+
RPCA->setPixelFormat(getMetalPixelFormat(RTFormats[I]));
1875+
Desc->colorAttachments()->setObject(RPCA, I);
1876+
RPCA->release();
1877+
}
1878+
1879+
if (DSFormat) {
1880+
const MTL::PixelFormat DSPixelFormat = getMetalPixelFormat(*DSFormat);
1881+
Desc->setDepthAttachmentPixelFormat(DSPixelFormat);
1882+
if (isStencilFormat(*DSFormat))
1883+
Desc->setStencilAttachmentPixelFormat(DSPixelFormat);
1884+
}
1885+
1886+
MTL::RenderPipelineState *PSO = Device->newRenderPipelineState(
1887+
Desc, MTL::PipelineOptionNone, /*reflection=*/nullptr, &Error);
1888+
if (Error)
1889+
return toError(Error);
1890+
1891+
MTL::DepthStencilDescriptor *DSDesc =
1892+
MTL::DepthStencilDescriptor::alloc()->init();
1893+
DSDesc->setDepthCompareFunction(MTL::CompareFunctionLess);
1894+
DSDesc->setDepthWriteEnabled(true);
1895+
MTL::DepthStencilState *DSState = Device->newDepthStencilState(DSDesc);
1896+
DSDesc->release();
1897+
1898+
// Pull threads-per-threadgroup from shader reflection.
1899+
MTL::Size MeshTGSize(1, 1, 1);
1900+
{
1901+
IRVersionedMSInfo MSInfo;
1902+
if (IRShaderReflectionCopyMeshInfo(MSIR.Reflection.get(),
1903+
IRReflectionVersion_1_0, &MSInfo)) {
1904+
MeshTGSize = MTL::Size(MSInfo.info_1_0.num_threads[0],
1905+
MSInfo.info_1_0.num_threads[1],
1906+
MSInfo.info_1_0.num_threads[2]);
1907+
}
1908+
IRShaderReflectionReleaseMeshInfo(&MSInfo);
1909+
}
1910+
1911+
MTL::Size ObjectTGSize(1, 1, 1);
1912+
if (AS) {
1913+
IRVersionedASInfo ASInfo;
1914+
if (IRShaderReflectionCopyAmplificationInfo(
1915+
ASIR.Reflection.get(), IRReflectionVersion_1_0, &ASInfo)) {
1916+
ObjectTGSize = MTL::Size(ASInfo.info_1_0.num_threads[0],
1917+
ASInfo.info_1_0.num_threads[1],
1918+
ASInfo.info_1_0.num_threads[2]);
1919+
}
1920+
IRShaderReflectionReleaseAmplificationInfo(&ASInfo);
1921+
}
1922+
1923+
return std::make_unique<MTLPipelineState>(
1924+
Name, std::move(RootSig), std::move(ArgBuffer), PSO, DSState,
1925+
MTL::CullModeNone, MeshTGSize, ObjectTGSize);
1926+
}
1927+
17631928
llvm::Error executeProgram(Pipeline &P) override {
17641929
InvocationState IS;
17651930

@@ -1808,40 +1973,68 @@ class MTLDevice : public offloadtest::Device {
18081973

18091974
if (auto Err = createComputeCommands(P, IS))
18101975
return Err;
1811-
} else {
1812-
TraditionalRasterPipelineCreateDesc PipelineDesc = {};
1813-
PipelineDesc.Topology = P.Bindings.Topology;
1814-
PipelineDesc.DSFormat = Format::D32FloatS8Uint;
1815-
for (auto &Shader : P.Shaders) {
1816-
ShaderContainer SC = {};
1817-
SC.EntryPoint = Shader.Entry;
1818-
SC.Shader = Shader.Shader.get();
1819-
PipelineDesc.setShader(Shader.Stage, std::move(SC));
1820-
}
1821-
1822-
for (auto &Attr : P.Bindings.VertexAttributes) {
1823-
auto FormatOrErr = toFormat(Attr.Format, Attr.Channels);
1824-
if (!FormatOrErr)
1825-
return FormatOrErr.takeError();
1826-
1827-
InputLayoutDesc Layout = {};
1828-
Layout.Name = Attr.Name;
1829-
Layout.Fmt = *FormatOrErr;
1830-
Layout.OffsetInBytes = Attr.Offset;
1831-
PipelineDesc.InputLayout.push_back(Layout);
1832-
}
1833-
1976+
} else if (P.isRaster()) {
18341977
auto FormatOrErr = toFormat(P.Bindings.RTargetBufferPtr->Format,
18351978
P.Bindings.RTargetBufferPtr->Channels);
18361979
if (!FormatOrErr)
18371980
return FormatOrErr.takeError();
18381981
PipelineDesc.RTFormats.push_back(*FormatOrErr);
18391982

1840-
auto PipelineStateOrErr = createTraditionalRasterPipeline(
1841-
"Graphics Pipeline State", Bindings, PipelineDesc);
1842-
if (!PipelineStateOrErr)
1843-
return PipelineStateOrErr.takeError();
1844-
IS.Pipeline = std::move(*PipelineStateOrErr);
1983+
llvm::SmallVector<Format> RTFormats;
1984+
RTFormats.push_back(*FormatOrErr);
1985+
1986+
if (P.isTraditionalRaster()) {
1987+
TraditionalRasterPipelineCreateDesc PipelineDesc = {};
1988+
PipelineDesc.Topology = P.Bindings.Topology;
1989+
PipelineDesc.DSFormat = Format::D32FloatS8Uint;
1990+
for (auto &Shader : P.Shaders) {
1991+
ShaderContainer SC = {};
1992+
SC.EntryPoint = Shader.Entry;
1993+
SC.Shader = Shader.Shader.get();
1994+
PipelineDesc.setShader(Shader.Stage, std::move(SC));
1995+
}
1996+
1997+
for (auto &Attr : P.Bindings.VertexAttributes) {
1998+
auto FormatOrErr = toFormat(Attr.Format, Attr.Channels);
1999+
if (!FormatOrErr)
2000+
return FormatOrErr.takeError();
2001+
2002+
InputLayoutDesc Layout = {};
2003+
Layout.Name = Attr.Name;
2004+
Layout.Fmt = *FormatOrErr;
2005+
Layout.OffsetInBytes = Attr.Offset;
2006+
PipelineDesc.InputLayout.push_back(Layout);
2007+
}
2008+
2009+
auto PipelineStateOrErr = createTraditionalRasterPipeline(
2010+
"Graphics Pipeline State", Bindings, PipelineDesc);
2011+
if (!PipelineStateOrErr)
2012+
return PipelineStateOrErr.takeError();
2013+
IS.Pipeline = std::move(*PipelineStateOrErr);
2014+
} else if (P.isMeshShaderRaster()) {
2015+
std::optional<ShaderContainer> AS;
2016+
ShaderContainer MS = {};
2017+
std::optional<ShaderContainer> PS;
2018+
for (auto &Shader : P.Shaders) {
2019+
ShaderContainer SC = {};
2020+
SC.EntryPoint = Shader.Entry;
2021+
SC.Shader = Shader.Shader.get();
2022+
if (Shader.Stage == Stages::Amplification)
2023+
AS = std::move(SC);
2024+
else if (Shader.Stage == Stages::Mesh)
2025+
MS = std::move(SC);
2026+
else if (Shader.Stage == Stages::Pixel)
2027+
PS = std::move(SC);
2028+
}
2029+
2030+
auto PipelineStateOrErr =
2031+
createPipelineAsMsPs("Mesh Shader Pipeline State", Bindings,
2032+
RTFormats, Format::D32FloatS8Uint, AS, MS, PS);
2033+
if (!PipelineStateOrErr)
2034+
return PipelineStateOrErr.takeError();
2035+
IS.Pipeline = std::move(*PipelineStateOrErr);
2036+
llvm::outs() << "Mesh Shader Pipeline created.\n";
2037+
}
18452038

18462039
ColorAttachmentFormatDesc ColorAttachment = {};
18472040
ColorAttachment.Fmt = *FormatOrErr;
@@ -1883,7 +2076,13 @@ class MTLDevice : public offloadtest::Device {
18832076
virtual ~MTLDevice() {};
18842077

18852078
private:
1886-
void queryCapabilities() {}
2079+
void queryCapabilities() {
2080+
// GPU Family Metal3 (macOS 13+) is where mesh shaders became available.
2081+
const bool MeshShaderSupported =
2082+
Device->supportsFamily(MTL::GPUFamilyMetal3);
2083+
Caps.insert(std::make_pair(
2084+
"MeshShader", makeCapability<bool>("MeshShader", MeshShaderSupported)));
2085+
}
18872086
};
18882087
} // namespace
18892088

test/lit.cfg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ def setDeviceFeatures(config, device, compiler):
158158
config.available_features.add("Int16")
159159
config.available_features.add("Int64")
160160
config.available_features.add("Half")
161+
if device["Features"].get("MeshShader", False):
162+
config.available_features.add("MeshShader")
161163

162164
if device["API"] == "Vulkan":
163165
if device["Features"].get("shaderInt16", False):

0 commit comments

Comments
 (0)