Skip to content

Commit bb0e8dd

Browse files
Add Mesh Shader support to Metal.
1 parent 7cd5ae1 commit bb0e8dd

2 files changed

Lines changed: 244 additions & 46 deletions

File tree

lib/API/MTL/MTLDevice.cpp

Lines changed: 242 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,9 @@ class MTLPipelineState : public offloadtest::PipelineState {
248248
MTL::DepthStencilState *DepthStencilState = nullptr;
249249
MTL::CullMode CullMode = MTL::CullModeNone;
250250

251+
MTL::Size MeshThreadsPerThreadgroup{1, 1, 1};
252+
MTL::Size ObjectThreadsPerThreadgroup{1, 1, 1};
253+
251254
MTLPipelineState(llvm::StringRef Name, IRRootSignaturePtr RootSig,
252255
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
253256
IRShaderReflectionPtr Reflection,
@@ -260,11 +263,15 @@ class MTLPipelineState : public offloadtest::PipelineState {
260263
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
261264
MTL::RenderPipelineState *RenderPipeline,
262265
MTL::DepthStencilState *DepthStencilState,
263-
MTL::CullMode CullMode)
266+
MTL::CullMode CullMode,
267+
MTL::Size MeshThreadsPerThreadgroup = {1, 1, 1},
268+
MTL::Size ObjectThreadsPerThreadgroup = {1, 1, 1})
264269
: offloadtest::PipelineState(GPUAPI::Metal), Name(Name),
265270
RootSig(std::move(RootSig)), ArgBuffer(std::move(ArgBuffer)),
266271
RenderPipeline(RenderPipeline), DepthStencilState(DepthStencilState),
267-
CullMode(CullMode) {}
272+
CullMode(CullMode),
273+
MeshThreadsPerThreadgroup(MeshThreadsPerThreadgroup),
274+
ObjectThreadsPerThreadgroup(ObjectThreadsPerThreadgroup) {}
268275

269276
~MTLPipelineState() override {
270277
if (ComputePipeline)
@@ -708,13 +715,29 @@ class MTLRenderEncoder : public offloadtest::RenderEncoder {
708715
llvm::Error dispatchMesh(const offloadtest::PipelineState &PSO,
709716
uint32_t GroupCountX, uint32_t GroupCountY,
710717
uint32_t GroupCountZ) override {
711-
(void)PSO;
712-
(void)GroupCountX;
713-
(void)GroupCountY;
714-
(void)GroupCountZ;
718+
if (!ViewportSet)
719+
return llvm::createStringError(std::errc::invalid_argument,
720+
"Viewport must be set before drawing.");
721+
if (!ScissorSet)
722+
return llvm::createStringError(std::errc::invalid_argument,
723+
"Scissor must be set before drawing.");
715724

716-
return llvm::createStringError(
717-
"dispatchMesh is unimplemented in the Metal backend.");
725+
const auto &MTLPSO = llvm::cast<MTLPipelineState>(PSO);
726+
if (!MTLPSO.RenderPipeline)
727+
return llvm::createStringError(
728+
std::errc::invalid_argument,
729+
"PipelineState bound to dispatchMesh() is not a render pipeline.");
730+
RenderEnc->setRenderPipelineState(MTLPSO.RenderPipeline);
731+
if (MTLPSO.DepthStencilState)
732+
RenderEnc->setDepthStencilState(MTLPSO.DepthStencilState);
733+
RenderEnc->setCullMode(MTLPSO.CullMode);
734+
RenderEnc->setFrontFacingWinding(MTL::WindingCounterClockwise);
735+
736+
RenderEnc->drawMeshThreadgroups(
737+
MTL::Size(GroupCountX, GroupCountY, GroupCountZ),
738+
MTLPSO.ObjectThreadsPerThreadgroup, MTLPSO.MeshThreadsPerThreadgroup);
739+
740+
return llvm::Error::success();
718741
}
719742

720743
void endEncodingImpl() override {
@@ -1413,12 +1436,23 @@ class MTLDevice : public offloadtest::Device {
14131436
Scissor.Height = static_cast<uint32_t>(Height);
14141437
Encoder.setScissor(Scissor);
14151438

1416-
if (IS.VB)
1417-
Encoder.setVertexBuffer(0, IS.VB.get(), 0, P.Bindings.getVertexStride());
1439+
if (P.isTraditionalRaster()) {
1440+
if (IS.VB)
1441+
Encoder.setVertexBuffer(0, IS.VB.get(), 0,
1442+
P.Bindings.getVertexStride());
1443+
1444+
if (auto Err =
1445+
Encoder.drawInstanced(*IS.Pipeline.get(), P.getVertexCount(),
1446+
/*InstanceCount=*/1))
1447+
return Err;
1448+
} else {
1449+
if (auto Err = Encoder.dispatchMesh(
1450+
*IS.Pipeline.get(), P.DispatchParameters.DispatchGroupCount[0],
1451+
P.DispatchParameters.DispatchGroupCount[1],
1452+
P.DispatchParameters.DispatchGroupCount[2]))
1453+
return Err;
1454+
}
14181455

1419-
if (auto Err = Encoder.drawInstanced(*IS.Pipeline.get(), P.getVertexCount(),
1420-
/*InstanceCount=*/1))
1421-
return Err;
14221456
Encoder.endEncoding();
14231457

14241458
// Blit the render target into the readback buffer for CPU access.
@@ -1749,6 +1783,136 @@ class MTLDevice : public offloadtest::Device {
17491783
DSState, MTL::CullModeNone);
17501784
}
17511785

1786+
llvm::Expected<std::unique_ptr<PipelineState>>
1787+
createPipelineAsMsPs(llvm::StringRef Name, const BindingsDesc &BindingsDesc,
1788+
llvm::ArrayRef<Format> RTFormats,
1789+
std::optional<Format> DSFormat,
1790+
std::optional<ShaderContainer> AS, ShaderContainer MS,
1791+
std::optional<ShaderContainer> PS) {
1792+
IRRootSignaturePtr RootSig;
1793+
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer;
1794+
if (auto Err = createRootSignature(BindingsDesc, /*IsGraphics=*/true,
1795+
RootSig, ArgBuffer))
1796+
return Err;
1797+
1798+
NS::Error *Error = nullptr;
1799+
auto compileStage = [&](Stages Stage, const ShaderContainer &SC,
1800+
llvm::StringRef RoleName, MetalIR &OutIR,
1801+
MTLPtr<MTL::Library> &OutLib,
1802+
MTLPtr<MTL::Function> &OutFn) -> llvm::Error {
1803+
auto IROrErr =
1804+
convertToMetalIR(Stage, /*IsGraphics=*/true, RootSig.get(), SC);
1805+
if (!IROrErr)
1806+
return IROrErr.takeError();
1807+
OutIR = std::move(*IROrErr);
1808+
1809+
dispatch_data_t Data = IRMetalLibGetBytecodeData(OutIR.Binary.get());
1810+
NS::Error *Err = nullptr;
1811+
OutLib = MTLPtr<MTL::Library>(Device->newLibrary(Data, &Err));
1812+
if (Err)
1813+
return toError(Err);
1814+
1815+
OutFn = MTLPtr<MTL::Function>(OutLib->newFunction(
1816+
NS::String::string(SC.EntryPoint.c_str(), NS::UTF8StringEncoding)));
1817+
if (!OutFn)
1818+
return llvm::createStringError(
1819+
std::errc::invalid_argument,
1820+
"Failed to find %s entry point '%s' in Metal library.",
1821+
RoleName.data(), SC.EntryPoint.c_str());
1822+
return llvm::Error::success();
1823+
};
1824+
1825+
MetalIR MSIR;
1826+
MTLPtr<MTL::Library> MSLib;
1827+
MTLPtr<MTL::Function> MSFn;
1828+
if (auto Err = compileStage(Stages::Mesh, MS, "mesh", MSIR, MSLib, MSFn))
1829+
return Err;
1830+
1831+
MetalIR ASIR;
1832+
MTLPtr<MTL::Library> ASLib;
1833+
MTLPtr<MTL::Function> ASFn;
1834+
if (AS) {
1835+
if (auto Err = compileStage(Stages::Amplification, *AS, "amplification",
1836+
ASIR, ASLib, ASFn))
1837+
return Err;
1838+
}
1839+
1840+
MetalIR PSIR;
1841+
MTLPtr<MTL::Library> PSLib;
1842+
MTLPtr<MTL::Function> PSFn;
1843+
if (PS) {
1844+
if (auto Err =
1845+
compileStage(Stages::Pixel, *PS, "fragment", PSIR, PSLib, PSFn))
1846+
return Err;
1847+
}
1848+
1849+
MTL::MeshRenderPipelineDescriptor *Desc =
1850+
MTL::MeshRenderPipelineDescriptor::alloc()->init();
1851+
auto DescScope = llvm::scope_exit([&] { Desc->release(); });
1852+
1853+
Desc->setMeshFunction(MSFn.get());
1854+
if (ASFn)
1855+
Desc->setObjectFunction(ASFn.get());
1856+
if (PSFn)
1857+
Desc->setFragmentFunction(PSFn.get());
1858+
1859+
for (size_t I = 0; I < RTFormats.size(); ++I) {
1860+
MTL::RenderPipelineColorAttachmentDescriptor *RPCA =
1861+
MTL::RenderPipelineColorAttachmentDescriptor::alloc()->init();
1862+
RPCA->setPixelFormat(getMetalPixelFormat(RTFormats[I]));
1863+
Desc->colorAttachments()->setObject(RPCA, I);
1864+
RPCA->release();
1865+
}
1866+
1867+
if (DSFormat) {
1868+
const MTL::PixelFormat DSPixelFormat = getMetalPixelFormat(*DSFormat);
1869+
Desc->setDepthAttachmentPixelFormat(DSPixelFormat);
1870+
if (isStencilFormat(*DSFormat))
1871+
Desc->setStencilAttachmentPixelFormat(DSPixelFormat);
1872+
}
1873+
1874+
MTL::RenderPipelineState *PSO = Device->newRenderPipelineState(
1875+
Desc, MTL::PipelineOptionNone, /*reflection=*/nullptr, &Error);
1876+
if (Error)
1877+
return toError(Error);
1878+
1879+
MTL::DepthStencilDescriptor *DSDesc =
1880+
MTL::DepthStencilDescriptor::alloc()->init();
1881+
DSDesc->setDepthCompareFunction(MTL::CompareFunctionLess);
1882+
DSDesc->setDepthWriteEnabled(true);
1883+
MTL::DepthStencilState *DSState = Device->newDepthStencilState(DSDesc);
1884+
DSDesc->release();
1885+
1886+
// Pull threads-per-threadgroup from shader reflection.
1887+
MTL::Size MeshTGSize(1, 1, 1);
1888+
{
1889+
IRVersionedMSInfo MSInfo;
1890+
if (IRShaderReflectionCopyMeshInfo(MSIR.Reflection.get(),
1891+
IRReflectionVersion_1_0, &MSInfo)) {
1892+
MeshTGSize = MTL::Size(MSInfo.info_1_0.num_threads[0],
1893+
MSInfo.info_1_0.num_threads[1],
1894+
MSInfo.info_1_0.num_threads[2]);
1895+
}
1896+
IRShaderReflectionReleaseMeshInfo(&MSInfo);
1897+
}
1898+
1899+
MTL::Size ObjectTGSize(1, 1, 1);
1900+
if (AS) {
1901+
IRVersionedASInfo ASInfo;
1902+
if (IRShaderReflectionCopyAmplificationInfo(
1903+
ASIR.Reflection.get(), IRReflectionVersion_1_0, &ASInfo)) {
1904+
ObjectTGSize = MTL::Size(ASInfo.info_1_0.num_threads[0],
1905+
ASInfo.info_1_0.num_threads[1],
1906+
ASInfo.info_1_0.num_threads[2]);
1907+
}
1908+
IRShaderReflectionReleaseAmplificationInfo(&ASInfo);
1909+
}
1910+
1911+
return std::make_unique<MTLPipelineState>(
1912+
Name, std::move(RootSig), std::move(ArgBuffer), PSO, DSState,
1913+
MTL::CullModeNone, MeshTGSize, ObjectTGSize);
1914+
}
1915+
17521916
llvm::Error executeProgram(Pipeline &P) override {
17531917
InvocationState IS;
17541918

@@ -1797,32 +1961,7 @@ class MTLDevice : public offloadtest::Device {
17971961

17981962
if (auto Err = createComputeCommands(P, IS))
17991963
return Err;
1800-
} else {
1801-
ShaderContainer VS = {};
1802-
ShaderContainer PS = {};
1803-
for (auto &Shader : P.Shaders) {
1804-
if (Shader.Stage == Stages::Vertex) {
1805-
VS.EntryPoint = Shader.Entry;
1806-
VS.Shader = Shader.Shader.get();
1807-
} else if (Shader.Stage == Stages::Pixel) {
1808-
PS.EntryPoint = Shader.Entry;
1809-
PS.Shader = Shader.Shader.get();
1810-
}
1811-
}
1812-
1813-
llvm::SmallVector<InputLayoutDesc> InputLayout;
1814-
for (auto &Attr : P.Bindings.VertexAttributes) {
1815-
auto FormatOrErr = toFormat(Attr.Format, Attr.Channels);
1816-
if (!FormatOrErr)
1817-
return FormatOrErr.takeError();
1818-
1819-
InputLayoutDesc Desc = {};
1820-
Desc.Name = Attr.Name;
1821-
Desc.Fmt = *FormatOrErr;
1822-
Desc.OffsetInBytes = Attr.Offset;
1823-
InputLayout.push_back(Desc);
1824-
}
1825-
1964+
} else if (P.isRaster()) {
18261965
auto FormatOrErr = toFormat(P.Bindings.RTargetBufferPtr->Format,
18271966
P.Bindings.RTargetBufferPtr->Channels);
18281967
if (!FormatOrErr)
@@ -1831,12 +1970,63 @@ class MTLDevice : public offloadtest::Device {
18311970
llvm::SmallVector<Format> RTFormats;
18321971
RTFormats.push_back(*FormatOrErr);
18331972

1834-
auto PipelineStateOrErr =
1835-
createPipelineVsPs("Graphics Pipeline State", Bindings, InputLayout,
1836-
RTFormats, Format::D32FloatS8Uint, VS, PS);
1837-
if (!PipelineStateOrErr)
1838-
return PipelineStateOrErr.takeError();
1839-
IS.Pipeline = std::move(*PipelineStateOrErr);
1973+
if (P.isTraditionalRaster()) {
1974+
ShaderContainer VS = {};
1975+
ShaderContainer PS = {};
1976+
for (auto &Shader : P.Shaders) {
1977+
if (Shader.Stage == Stages::Vertex) {
1978+
VS.EntryPoint = Shader.Entry;
1979+
VS.Shader = Shader.Shader.get();
1980+
} else if (Shader.Stage == Stages::Pixel) {
1981+
PS.EntryPoint = Shader.Entry;
1982+
PS.Shader = Shader.Shader.get();
1983+
}
1984+
}
1985+
1986+
// Create the input layout based on the vertex attributes.
1987+
llvm::SmallVector<InputLayoutDesc> InputLayout;
1988+
for (auto &Attr : P.Bindings.VertexAttributes) {
1989+
auto FormatOrErr = toFormat(Attr.Format, Attr.Channels);
1990+
if (!FormatOrErr)
1991+
return FormatOrErr.takeError();
1992+
1993+
InputLayoutDesc Desc = {};
1994+
Desc.Name = Attr.Name;
1995+
Desc.Fmt = *FormatOrErr;
1996+
Desc.OffsetInBytes = Attr.Offset;
1997+
InputLayout.push_back(Desc);
1998+
}
1999+
2000+
auto PipelineStateOrErr =
2001+
createPipelineVsPs("Graphics Pipeline State", Bindings, InputLayout,
2002+
RTFormats, Format::D32FloatS8Uint, VS, PS);
2003+
if (!PipelineStateOrErr)
2004+
return PipelineStateOrErr.takeError();
2005+
IS.Pipeline = std::move(*PipelineStateOrErr);
2006+
} else if (P.isMeshShaderRaster()) {
2007+
std::optional<ShaderContainer> AS;
2008+
ShaderContainer MS = {};
2009+
std::optional<ShaderContainer> PS;
2010+
for (auto &Shader : P.Shaders) {
2011+
ShaderContainer SC = {};
2012+
SC.EntryPoint = Shader.Entry;
2013+
SC.Shader = Shader.Shader.get();
2014+
if (Shader.Stage == Stages::Amplification)
2015+
AS = std::move(SC);
2016+
else if (Shader.Stage == Stages::Mesh)
2017+
MS = std::move(SC);
2018+
else if (Shader.Stage == Stages::Pixel)
2019+
PS = std::move(SC);
2020+
}
2021+
2022+
auto PipelineStateOrErr =
2023+
createPipelineAsMsPs("Mesh Shader Pipeline State", Bindings,
2024+
RTFormats, Format::D32FloatS8Uint, AS, MS, PS);
2025+
if (!PipelineStateOrErr)
2026+
return PipelineStateOrErr.takeError();
2027+
IS.Pipeline = std::move(*PipelineStateOrErr);
2028+
llvm::outs() << "Mesh Shader Pipeline created.\n";
2029+
}
18402030

18412031
ColorAttachmentFormatDesc ColorAttachment = {};
18422032
ColorAttachment.Fmt = *FormatOrErr;
@@ -1878,7 +2068,13 @@ class MTLDevice : public offloadtest::Device {
18782068
virtual ~MTLDevice() {};
18792069

18802070
private:
1881-
void queryCapabilities() {}
2071+
void queryCapabilities() {
2072+
// GPU Family Metal3 (macOS 13+) is where mesh shaders became available.
2073+
const bool MeshShaderSupported =
2074+
Device->supportsFamily(MTL::GPUFamilyMetal3);
2075+
Caps.insert(std::make_pair(
2076+
"MeshShader", makeCapability<bool>("MeshShader", MeshShaderSupported)));
2077+
}
18822078
};
18832079
} // namespace
18842080

test/lit.cfg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ def setDeviceFeatures(config, device, compiler):
158158
config.available_features.add("Int16")
159159
config.available_features.add("Int64")
160160
config.available_features.add("Half")
161+
if device["Features"].get("MeshShader", False):
162+
config.available_features.add("MeshShader")
161163

162164
if device["API"] == "Vulkan":
163165
if device["Features"].get("shaderInt16", False):

0 commit comments

Comments
 (0)