Skip to content

Commit 1826456

Browse files
Add Mesh Shader support to Metal.
1 parent 7cd5ae1 commit 1826456

2 files changed

Lines changed: 245 additions & 46 deletions

File tree

lib/API/MTL/MTLDevice.cpp

Lines changed: 243 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,9 @@ class MTLPipelineState : public offloadtest::PipelineState {
248248
MTL::DepthStencilState *DepthStencilState = nullptr;
249249
MTL::CullMode CullMode = MTL::CullModeNone;
250250

251+
MTL::Size MeshThreadsPerThreadgroup{1, 1, 1};
252+
MTL::Size ObjectThreadsPerThreadgroup{1, 1, 1};
253+
251254
MTLPipelineState(llvm::StringRef Name, IRRootSignaturePtr RootSig,
252255
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
253256
IRShaderReflectionPtr Reflection,
@@ -260,11 +263,15 @@ class MTLPipelineState : public offloadtest::PipelineState {
260263
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
261264
MTL::RenderPipelineState *RenderPipeline,
262265
MTL::DepthStencilState *DepthStencilState,
263-
MTL::CullMode CullMode)
266+
MTL::CullMode CullMode,
267+
MTL::Size MeshThreadsPerThreadgroup = {1, 1, 1},
268+
MTL::Size ObjectThreadsPerThreadgroup = {1, 1, 1})
264269
: offloadtest::PipelineState(GPUAPI::Metal), Name(Name),
265270
RootSig(std::move(RootSig)), ArgBuffer(std::move(ArgBuffer)),
266271
RenderPipeline(RenderPipeline), DepthStencilState(DepthStencilState),
267-
CullMode(CullMode) {}
272+
CullMode(CullMode),
273+
MeshThreadsPerThreadgroup(MeshThreadsPerThreadgroup),
274+
ObjectThreadsPerThreadgroup(ObjectThreadsPerThreadgroup) {}
268275

269276
~MTLPipelineState() override {
270277
if (ComputePipeline)
@@ -708,13 +715,30 @@ class MTLRenderEncoder : public offloadtest::RenderEncoder {
708715
/// Encode a mesh-shader draw on this render encoder.
///
/// Fails with invalid_argument when the viewport or the scissor has not been
/// set yet, or when \p PSO carries no render pipeline. Otherwise binds the
/// render pipeline, the depth-stencil state (when present), the cull mode and
/// a counter-clockwise front-facing winding, then issues one
/// drawMeshThreadgroups call.
///
/// \param PSO          Pipeline state; cast to MTLPipelineState.
/// \param GroupCountX  Threadgroup count along X passed to the draw.
/// \param GroupCountY  Threadgroup count along Y passed to the draw.
/// \param GroupCountZ  Threadgroup count along Z passed to the draw.
/// \returns success once the draw has been encoded, otherwise the error above.
///
/// NOTE(review): in this diff rendering, the code lines that follow a marker
/// ending in '-' belong to the removed "unimplemented" stub, not to the new
/// implementation.
llvm::Error dispatchMesh(const offloadtest::PipelineState &PSO,
709716
uint32_t GroupCountX, uint32_t GroupCountY,
710717
uint32_t GroupCountZ) override {
711-
(void)PSO;
712-
(void)GroupCountX;
713-
(void)GroupCountY;
714-
(void)GroupCountZ;
718+
if (!ViewportSet)
719+
return llvm::createStringError(std::errc::invalid_argument,
720+
"Viewport must be set before drawing.");
721+
if (!ScissorSet)
722+
return llvm::createStringError(std::errc::invalid_argument,
723+
"Scissor must be set before drawing.");
715724

716-
return llvm::createStringError(
717-
"dispatchMesh is unimplemented in the Metal backend.");
725+
const auto &MTLPSO = llvm::cast<MTLPipelineState>(PSO);
726+
if (!MTLPSO.RenderPipeline)
727+
return llvm::createStringError(
728+
std::errc::invalid_argument,
729+
"PipelineState bound to dispatchMesh() is not a render pipeline.");
730+
RenderEnc->setRenderPipelineState(MTLPSO.RenderPipeline);
731+
if (MTLPSO.DepthStencilState)
732+
RenderEnc->setDepthStencilState(MTLPSO.DepthStencilState);
733+
RenderEnc->setCullMode(MTLPSO.CullMode);
734+
// Match the DX/VK convention (CCW = front) hardcoded in those backends.
735+
RenderEnc->setFrontFacingWinding(MTL::WindingCounterClockwise);
736+
737+
// The group counts come from the caller; the object and mesh
// threads-per-threadgroup sizes are the ones stored on the pipeline state.
RenderEnc->drawMeshThreadgroups(
738+
MTL::Size(GroupCountX, GroupCountY, GroupCountZ),
739+
MTLPSO.ObjectThreadsPerThreadgroup, MTLPSO.MeshThreadsPerThreadgroup);
740+
741+
return llvm::Error::success();
718742
}
719743

720744
void endEncodingImpl() override {
@@ -1413,12 +1437,23 @@ class MTLDevice : public offloadtest::Device {
14131437
Scissor.Height = static_cast<uint32_t>(Height);
14141438
Encoder.setScissor(Scissor);
14151439

1416-
if (IS.VB)
1417-
Encoder.setVertexBuffer(0, IS.VB.get(), 0, P.Bindings.getVertexStride());
1440+
if (P.isTraditionalRaster()) {
1441+
if (IS.VB)
1442+
Encoder.setVertexBuffer(0, IS.VB.get(), 0,
1443+
P.Bindings.getVertexStride());
1444+
1445+
if (auto Err =
1446+
Encoder.drawInstanced(*IS.Pipeline.get(), P.getVertexCount(),
1447+
/*InstanceCount=*/1))
1448+
return Err;
1449+
} else {
1450+
if (auto Err = Encoder.dispatchMesh(
1451+
*IS.Pipeline.get(), P.DispatchParameters.DispatchGroupCount[0],
1452+
P.DispatchParameters.DispatchGroupCount[1],
1453+
P.DispatchParameters.DispatchGroupCount[2]))
1454+
return Err;
1455+
}
14181456

1419-
if (auto Err = Encoder.drawInstanced(*IS.Pipeline.get(), P.getVertexCount(),
1420-
/*InstanceCount=*/1))
1421-
return Err;
14221457
Encoder.endEncoding();
14231458

14241459
// Blit the render target into the readback buffer for CPU access.
@@ -1749,6 +1784,136 @@ class MTLDevice : public offloadtest::Device {
17491784
DSState, MTL::CullModeNone);
17501785
}
17511786

1787+
/// Create a mesh-shading render pipeline state.
///
/// The mesh stage is mandatory. The amplification stage (bound as the Metal
/// object function) and the pixel stage (bound as the fragment function) are
/// optional. Every provided stage is converted to Metal IR against a root
/// signature built from \p BindingsDesc, loaded as a Metal library and looked
/// up by entry point.
///
/// The threads-per-threadgroup sizes for the mesh and object stages are read
/// from shader reflection and stored on the returned pipeline state; each
/// falls back to (1, 1, 1) when its reflection info cannot be copied.
///
/// \param Name          Name stored on the resulting pipeline state.
/// \param BindingsDesc  Bindings used to build the root signature and the
///                      top-level argument buffer.
/// \param RTFormats     One entry per color attachment, in attachment order.
/// \param DSFormat      Depth attachment format; also used for the stencil
///                      attachment when it is a stencil format.
/// \param AS            Optional amplification shader.
/// \param MS            Mesh shader.
/// \param PS            Optional pixel shader.
/// \returns the new pipeline state, or the first error hit while building the
///          root signature, compiling a stage or creating the Metal pipeline.
llvm::Expected<std::unique_ptr<PipelineState>>
createPipelineAsMsPs(llvm::StringRef Name, const BindingsDesc &BindingsDesc,
                     llvm::ArrayRef<Format> RTFormats,
                     std::optional<Format> DSFormat,
                     std::optional<ShaderContainer> AS, ShaderContainer MS,
                     std::optional<ShaderContainer> PS) {
  IRRootSignaturePtr RootSig;
  std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer;
  if (auto Err = createRootSignature(BindingsDesc, /*IsGraphics=*/true,
                                     RootSig, ArgBuffer))
    return Err;

  // Converts one stage to Metal IR, loads it as a library and resolves its
  // entry point. RoleName is only used to label error messages.
  auto compileStage = [&](Stages Stage, const ShaderContainer &SC,
                          llvm::StringRef RoleName, MetalIR &OutIR,
                          MTLPtr<MTL::Library> &OutLib,
                          MTLPtr<MTL::Function> &OutFn) -> llvm::Error {
    // A StringRef is not guaranteed to be NUL-terminated, so take an owned
    // copy before handing it to a printf-style formatter.
    const std::string Role = RoleName.str();

    auto IROrErr =
        convertToMetalIR(Stage, /*IsGraphics=*/true, RootSig.get(), SC);
    if (!IROrErr)
      return IROrErr.takeError();
    OutIR = std::move(*IROrErr);

    dispatch_data_t Data = IRMetalLibGetBytecodeData(OutIR.Binary.get());
    NS::Error *Err = nullptr;
    OutLib = MTLPtr<MTL::Library>(Device->newLibrary(Data, &Err));
    if (Err)
      return toError(Err);
    // Never dereference a missing library, even if no NS::Error was reported.
    if (!OutLib)
      return llvm::createStringError(
          std::errc::invalid_argument,
          "Failed to create Metal library for %s shader.", Role.c_str());

    OutFn = MTLPtr<MTL::Function>(OutLib->newFunction(
        NS::String::string(SC.EntryPoint.c_str(), NS::UTF8StringEncoding)));
    if (!OutFn)
      return llvm::createStringError(
          std::errc::invalid_argument,
          "Failed to find %s entry point '%s' in Metal library.",
          Role.c_str(), SC.EntryPoint.c_str());
    return llvm::Error::success();
  };

  MetalIR MSIR;
  MTLPtr<MTL::Library> MSLib;
  MTLPtr<MTL::Function> MSFn;
  if (auto Err = compileStage(Stages::Mesh, MS, "mesh", MSIR, MSLib, MSFn))
    return Err;

  MetalIR ASIR;
  MTLPtr<MTL::Library> ASLib;
  MTLPtr<MTL::Function> ASFn;
  if (AS) {
    if (auto Err = compileStage(Stages::Amplification, *AS, "amplification",
                                ASIR, ASLib, ASFn))
      return Err;
  }

  MetalIR PSIR;
  MTLPtr<MTL::Library> PSLib;
  MTLPtr<MTL::Function> PSFn;
  if (PS) {
    if (auto Err =
            compileStage(Stages::Pixel, *PS, "fragment", PSIR, PSLib, PSFn))
      return Err;
  }

  MTL::MeshRenderPipelineDescriptor *Desc =
      MTL::MeshRenderPipelineDescriptor::alloc()->init();
  auto DescScope = llvm::scope_exit([&] { Desc->release(); });

  Desc->setMeshFunction(MSFn.get());
  if (ASFn)
    Desc->setObjectFunction(ASFn.get());
  if (PSFn)
    Desc->setFragmentFunction(PSFn.get());

  for (size_t I = 0; I < RTFormats.size(); ++I) {
    MTL::RenderPipelineColorAttachmentDescriptor *RPCA =
        MTL::RenderPipelineColorAttachmentDescriptor::alloc()->init();
    RPCA->setPixelFormat(getMetalPixelFormat(RTFormats[I]));
    Desc->colorAttachments()->setObject(RPCA, I);
    RPCA->release();
  }

  if (DSFormat) {
    const MTL::PixelFormat DSPixelFormat = getMetalPixelFormat(*DSFormat);
    Desc->setDepthAttachmentPixelFormat(DSPixelFormat);
    if (isStencilFormat(*DSFormat))
      Desc->setStencilAttachmentPixelFormat(DSPixelFormat);
  }

  NS::Error *Error = nullptr;
  MTL::RenderPipelineState *PSO = Device->newRenderPipelineState(
      Desc, MTL::PipelineOptionNone, /*reflection=*/nullptr, &Error);
  if (Error) {
    // Do not leak a pipeline object that came back alongside an error.
    if (PSO)
      PSO->release();
    return toError(Error);
  }
  // A null pipeline without an NS::Error must not be stored as if it were
  // valid; dispatchMesh() would only discover the problem at draw time.
  if (!PSO)
    return llvm::createStringError(
        std::errc::invalid_argument,
        "Failed to create mesh render pipeline state '%s'.",
        Name.str().c_str());

  MTL::DepthStencilDescriptor *DSDesc =
      MTL::DepthStencilDescriptor::alloc()->init();
  DSDesc->setDepthCompareFunction(MTL::CompareFunctionLess);
  DSDesc->setDepthWriteEnabled(true);
  MTL::DepthStencilState *DSState = Device->newDepthStencilState(DSDesc);
  DSDesc->release();

  // Pull threads-per-threadgroup from shader reflection. The info structs are
  // released only after a successful copy, so the release functions never
  // see a struct the converter did not fill in.
  MTL::Size MeshTGSize(1, 1, 1);
  {
    IRVersionedMSInfo MSInfo = {};
    if (IRShaderReflectionCopyMeshInfo(MSIR.Reflection.get(),
                                       IRReflectionVersion_1_0, &MSInfo)) {
      MeshTGSize = MTL::Size(MSInfo.info_1_0.num_threads[0],
                             MSInfo.info_1_0.num_threads[1],
                             MSInfo.info_1_0.num_threads[2]);
      IRShaderReflectionReleaseMeshInfo(&MSInfo);
    }
  }

  MTL::Size ObjectTGSize(1, 1, 1);
  if (AS) {
    IRVersionedASInfo ASInfo = {};
    if (IRShaderReflectionCopyAmplificationInfo(
            ASIR.Reflection.get(), IRReflectionVersion_1_0, &ASInfo)) {
      ObjectTGSize = MTL::Size(ASInfo.info_1_0.num_threads[0],
                               ASInfo.info_1_0.num_threads[1],
                               ASInfo.info_1_0.num_threads[2]);
      IRShaderReflectionReleaseAmplificationInfo(&ASInfo);
    }
  }

  return std::make_unique<MTLPipelineState>(
      Name, std::move(RootSig), std::move(ArgBuffer), PSO, DSState,
      MTL::CullModeNone, MeshTGSize, ObjectTGSize);
}
1916+
17521917
llvm::Error executeProgram(Pipeline &P) override {
17531918
InvocationState IS;
17541919

@@ -1797,32 +1962,7 @@ class MTLDevice : public offloadtest::Device {
17971962

17981963
if (auto Err = createComputeCommands(P, IS))
17991964
return Err;
1800-
} else {
1801-
ShaderContainer VS = {};
1802-
ShaderContainer PS = {};
1803-
for (auto &Shader : P.Shaders) {
1804-
if (Shader.Stage == Stages::Vertex) {
1805-
VS.EntryPoint = Shader.Entry;
1806-
VS.Shader = Shader.Shader.get();
1807-
} else if (Shader.Stage == Stages::Pixel) {
1808-
PS.EntryPoint = Shader.Entry;
1809-
PS.Shader = Shader.Shader.get();
1810-
}
1811-
}
1812-
1813-
llvm::SmallVector<InputLayoutDesc> InputLayout;
1814-
for (auto &Attr : P.Bindings.VertexAttributes) {
1815-
auto FormatOrErr = toFormat(Attr.Format, Attr.Channels);
1816-
if (!FormatOrErr)
1817-
return FormatOrErr.takeError();
1818-
1819-
InputLayoutDesc Desc = {};
1820-
Desc.Name = Attr.Name;
1821-
Desc.Fmt = *FormatOrErr;
1822-
Desc.OffsetInBytes = Attr.Offset;
1823-
InputLayout.push_back(Desc);
1824-
}
1825-
1965+
} else if (P.isRaster()) {
18261966
auto FormatOrErr = toFormat(P.Bindings.RTargetBufferPtr->Format,
18271967
P.Bindings.RTargetBufferPtr->Channels);
18281968
if (!FormatOrErr)
@@ -1831,12 +1971,63 @@ class MTLDevice : public offloadtest::Device {
18311971
llvm::SmallVector<Format> RTFormats;
18321972
RTFormats.push_back(*FormatOrErr);
18331973

1834-
auto PipelineStateOrErr =
1835-
createPipelineVsPs("Graphics Pipeline State", Bindings, InputLayout,
1836-
RTFormats, Format::D32FloatS8Uint, VS, PS);
1837-
if (!PipelineStateOrErr)
1838-
return PipelineStateOrErr.takeError();
1839-
IS.Pipeline = std::move(*PipelineStateOrErr);
1974+
if (P.isTraditionalRaster()) {
1975+
ShaderContainer VS = {};
1976+
ShaderContainer PS = {};
1977+
for (auto &Shader : P.Shaders) {
1978+
if (Shader.Stage == Stages::Vertex) {
1979+
VS.EntryPoint = Shader.Entry;
1980+
VS.Shader = Shader.Shader.get();
1981+
} else if (Shader.Stage == Stages::Pixel) {
1982+
PS.EntryPoint = Shader.Entry;
1983+
PS.Shader = Shader.Shader.get();
1984+
}
1985+
}
1986+
1987+
// Create the input layout based on the vertex attributes.
1988+
llvm::SmallVector<InputLayoutDesc> InputLayout;
1989+
for (auto &Attr : P.Bindings.VertexAttributes) {
1990+
auto FormatOrErr = toFormat(Attr.Format, Attr.Channels);
1991+
if (!FormatOrErr)
1992+
return FormatOrErr.takeError();
1993+
1994+
InputLayoutDesc Desc = {};
1995+
Desc.Name = Attr.Name;
1996+
Desc.Fmt = *FormatOrErr;
1997+
Desc.OffsetInBytes = Attr.Offset;
1998+
InputLayout.push_back(Desc);
1999+
}
2000+
2001+
auto PipelineStateOrErr =
2002+
createPipelineVsPs("Graphics Pipeline State", Bindings, InputLayout,
2003+
RTFormats, Format::D32FloatS8Uint, VS, PS);
2004+
if (!PipelineStateOrErr)
2005+
return PipelineStateOrErr.takeError();
2006+
IS.Pipeline = std::move(*PipelineStateOrErr);
2007+
} else if (P.isMeshShaderRaster()) {
2008+
std::optional<ShaderContainer> AS;
2009+
ShaderContainer MS = {};
2010+
std::optional<ShaderContainer> PS;
2011+
for (auto &Shader : P.Shaders) {
2012+
ShaderContainer SC = {};
2013+
SC.EntryPoint = Shader.Entry;
2014+
SC.Shader = Shader.Shader.get();
2015+
if (Shader.Stage == Stages::Amplification)
2016+
AS = std::move(SC);
2017+
else if (Shader.Stage == Stages::Mesh)
2018+
MS = std::move(SC);
2019+
else if (Shader.Stage == Stages::Pixel)
2020+
PS = std::move(SC);
2021+
}
2022+
2023+
auto PipelineStateOrErr =
2024+
createPipelineAsMsPs("Mesh Shader Pipeline State", Bindings,
2025+
RTFormats, Format::D32FloatS8Uint, AS, MS, PS);
2026+
if (!PipelineStateOrErr)
2027+
return PipelineStateOrErr.takeError();
2028+
IS.Pipeline = std::move(*PipelineStateOrErr);
2029+
llvm::outs() << "Mesh Shader Pipeline created.\n";
2030+
}
18402031

18412032
ColorAttachmentFormatDesc ColorAttachment = {};
18422033
ColorAttachment.Fmt = *FormatOrErr;
@@ -1878,7 +2069,13 @@ class MTLDevice : public offloadtest::Device {
18782069
virtual ~MTLDevice() {};
18792070

18802071
private:
1881-
void queryCapabilities() {}
2072+
void queryCapabilities() {
2073+
// GPU Family Metal3 (macOS 13+) is where mesh shaders became available.
2074+
const bool MeshShaderSupported =
2075+
Device->supportsFamily(MTL::GPUFamilyMetal3);
2076+
Caps.insert(std::make_pair(
2077+
"MeshShader", makeCapability<bool>("MeshShader", MeshShaderSupported)));
2078+
}
18822079
};
18832080
} // namespace
18842081

test/lit.cfg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ def setDeviceFeatures(config, device, compiler):
158158
config.available_features.add("Int16")
159159
config.available_features.add("Int64")
160160
config.available_features.add("Half")
161+
if device["Features"].get("MeshShader", False):
162+
config.available_features.add("MeshShader")
161163

162164
if device["API"] == "Vulkan":
163165
if device["Features"].get("shaderInt16", False):

0 commit comments

Comments
 (0)