@@ -248,6 +248,9 @@ class MTLPipelineState : public offloadtest::PipelineState {
248248 MTL ::DepthStencilState *DepthStencilState = nullptr ;
249249 MTL ::CullMode CullMode = MTL ::CullModeNone;
250250
// Threads-per-threadgroup sizes for the mesh stage and the object stage.
// dispatchMesh() forwards both to drawMeshThreadgroups(); each defaults to
// 1x1x1.
251+ MTL ::Size MeshThreadsPerThreadgroup{1 , 1 , 1 };
252+ MTL ::Size ObjectThreadsPerThreadgroup{1 , 1 , 1 };
253+
251254 MTLPipelineState (llvm::StringRef Name, IRRootSignaturePtr RootSig,
252255 std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
253256 IRShaderReflectionPtr Reflection,
@@ -260,11 +263,15 @@ class MTLPipelineState : public offloadtest::PipelineState {
260263 std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
261264 MTL ::RenderPipelineState *RenderPipeline,
262265 MTL ::DepthStencilState *DepthStencilState,
263- MTL ::CullMode CullMode)
266+ MTL ::CullMode CullMode,
267+ MTL ::Size MeshThreadsPerThreadgroup = {1 , 1 , 1 },
268+ MTL ::Size ObjectThreadsPerThreadgroup = {1 , 1 , 1 })
264269 : offloadtest::PipelineState(GPUAPI ::Metal), Name(Name),
265270 RootSig(std::move(RootSig)), ArgBuffer(std::move(ArgBuffer)),
266271 RenderPipeline(RenderPipeline), DepthStencilState(DepthStencilState),
267- CullMode(CullMode) {}
272+ CullMode(CullMode),
273+ MeshThreadsPerThreadgroup(MeshThreadsPerThreadgroup),
274+ ObjectThreadsPerThreadgroup(ObjectThreadsPerThreadgroup) {}
268275
269276 ~MTLPipelineState () override {
270277 if (ComputePipeline)
@@ -708,13 +715,29 @@ class MTLRenderEncoder : public offloadtest::RenderEncoder {
/// Encodes a mesh-shader draw on this render encoder.
///
/// Fails with invalid_argument if the viewport or scissor has not been set, or
/// if \p PSO carries no render pipeline. Otherwise binds the pipeline state,
/// the optional depth-stencil state and the cull mode from \p PSO, then issues
/// drawMeshThreadgroups() with the given group counts and the per-threadgroup
/// sizes stored on the pipeline state.
708715 llvm::Error dispatchMesh (const offloadtest::PipelineState &PSO ,
709716 uint32_t GroupCountX, uint32_t GroupCountY,
710717 uint32_t GroupCountZ) override {
711- (void )PSO ;
712- (void )GroupCountX;
713- (void )GroupCountY;
714- (void )GroupCountZ;
// Viewport and scissor are checked first, before anything is encoded.
718+ if (!ViewportSet)
719+ return llvm::createStringError (std::errc::invalid_argument,
720+ " Viewport must be set before drawing." );
721+ if (!ScissorSet)
722+ return llvm::createStringError (std::errc::invalid_argument,
723+ " Scissor must be set before drawing." );
715724
716- return llvm::createStringError (
717- " dispatchMesh is unimplemented in the Metal backend." );
// NOTE(review): llvm::cast<> presumes PSO really is an MTLPipelineState;
// confirm callers never pass a pipeline state from another backend.
725+ const auto &MTLPSO = llvm::cast<MTLPipelineState>(PSO );
// A pipeline state without a RenderPipeline is rejected here.
726+ if (!MTLPSO .RenderPipeline )
727+ return llvm::createStringError (
728+ std::errc::invalid_argument,
729+ " PipelineState bound to dispatchMesh() is not a render pipeline." );
730+ RenderEnc->setRenderPipelineState (MTLPSO .RenderPipeline );
// Depth-stencil state is optional; it is only bound when present.
731+ if (MTLPSO .DepthStencilState )
732+ RenderEnc->setDepthStencilState (MTLPSO .DepthStencilState );
733+ RenderEnc->setCullMode (MTLPSO .CullMode );
// Front-face winding is hard-coded to counter-clockwise.
734+ RenderEnc->setFrontFacingWinding (MTL ::WindingCounterClockwise);
735+
// Argument order: threadgroup counts, then object-stage size, then
// mesh-stage size.
736+ RenderEnc->drawMeshThreadgroups (
737+ MTL::Size (GroupCountX, GroupCountY, GroupCountZ),
738+ MTLPSO .ObjectThreadsPerThreadgroup , MTLPSO .MeshThreadsPerThreadgroup );
739+
740+ return llvm::Error::success ();
718741 }
719742
720743 void endEncodingImpl () override {
@@ -1413,12 +1436,23 @@ class MTLDevice : public offloadtest::Device {
14131436 Scissor.Height = static_cast <uint32_t >(Height);
14141437 Encoder.setScissor (Scissor);
14151438
1416- if (IS .VB )
1417- Encoder.setVertexBuffer (0 , IS .VB .get (), 0 , P.Bindings .getVertexStride ());
1439+ if (P.isTraditionalRaster ()) {
1440+ if (IS .VB )
1441+ Encoder.setVertexBuffer (0 , IS .VB .get (), 0 ,
1442+ P.Bindings .getVertexStride ());
1443+
1444+ if (auto Err =
1445+ Encoder.drawInstanced (*IS .Pipeline .get (), P.getVertexCount (),
1446+ /* InstanceCount=*/ 1 ))
1447+ return Err;
1448+ } else {
1449+ if (auto Err = Encoder.dispatchMesh (
1450+ *IS .Pipeline .get (), P.DispatchParameters .DispatchGroupCount [0 ],
1451+ P.DispatchParameters .DispatchGroupCount [1 ],
1452+ P.DispatchParameters .DispatchGroupCount [2 ]))
1453+ return Err;
1454+ }
14181455
1419- if (auto Err = Encoder.drawInstanced (*IS .Pipeline .get (), P.getVertexCount (),
1420- /* InstanceCount=*/ 1 ))
1421- return Err;
14221456 Encoder.endEncoding ();
14231457
14241458 // Blit the render target into the readback buffer for CPU access.
@@ -1749,6 +1783,136 @@ class MTLDevice : public offloadtest::Device {
17491783 DSState, MTL ::CullModeNone);
17501784 }
17511785
/// Builds a Metal mesh-shading render pipeline from an optional amplification
/// shader (\p AS), a mesh shader (\p MS) and an optional pixel shader (\p PS).
///
/// Every provided stage is converted to Metal IR, loaded as a Metal library and
/// has its entry point resolved. The render-target formats in \p RTFormats and
/// the optional \p DSFormat are applied to the pipeline descriptor. The
/// threads-per-threadgroup sizes are read from shader reflection and fall back
/// to 1x1x1 when the reflection query does not succeed.
///
/// \returns the new MTLPipelineState, or the error produced by root-signature
/// creation, stage conversion, library creation, entry-point lookup or
/// pipeline-state creation.
1786+ llvm::Expected<std::unique_ptr<PipelineState>>
1787+ createPipelineAsMsPs (llvm::StringRef Name, const BindingsDesc &BindingsDesc,
1788+ llvm::ArrayRef<Format> RTFormats,
1789+ std::optional<Format> DSFormat,
1790+ std::optional<ShaderContainer> AS , ShaderContainer MS ,
1791+ std::optional<ShaderContainer> PS ) {
1792+ IRRootSignaturePtr RootSig;
1793+ std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer;
1794+ if (auto Err = createRootSignature (BindingsDesc, /* IsGraphics=*/ true ,
1795+ RootSig, ArgBuffer))
1796+ return Err;
1797+
// Receives the error reported by newRenderPipelineState() further down.
1798+ NS ::Error *Error = nullptr ;
// Per-stage helper: converts SC to Metal IR against RootSig, creates a
// library from the resulting bytecode and looks up the entry point. Results
// are written to OutIR / OutLib / OutFn; RoleName is only used in the
// "entry point not found" message.
1799+ auto compileStage = [&](Stages Stage, const ShaderContainer &SC ,
1800+ llvm::StringRef RoleName, MetalIR &OutIR,
1801+ MTLPtr<MTL ::Library> &OutLib,
1802+ MTLPtr<MTL ::Function> &OutFn) -> llvm::Error {
1803+ auto IROrErr =
1804+ convertToMetalIR (Stage, /* IsGraphics=*/ true , RootSig.get (), SC );
1805+ if (!IROrErr)
1806+ return IROrErr.takeError ();
1807+ OutIR = std::move (*IROrErr);
1808+
1809+ dispatch_data_t Data = IRMetalLibGetBytecodeData (OutIR.Binary .get ());
1810+ NS ::Error *Err = nullptr ;
1811+ OutLib = MTLPtr<MTL ::Library>(Device->newLibrary (Data, &Err));
1812+ if (Err)
1813+ return toError (Err);
1814+
1815+ OutFn = MTLPtr<MTL ::Function>(OutLib->newFunction (
1816+ NS::String::string (SC .EntryPoint.c_str(), NS::UTF8StringEncoding)));
// NOTE(review): RoleName.data() is fed to a %s format, which needs a
// null-terminated string. The callers below pass string literals, so this
// holds today; confirm before passing any other StringRef.
1817+ if (!OutFn)
1818+ return llvm::createStringError (
1819+ std::errc::invalid_argument,
1820+ " Failed to find %s entry point '%s' in Metal library." ,
1821+ RoleName.data (), SC .EntryPoint .c_str ());
1822+ return llvm::Error::success ();
1823+ };
1824+
// The mesh stage is mandatory and is always compiled.
1825+ MetalIR MSIR ;
1826+ MTLPtr<MTL ::Library> MSLib;
1827+ MTLPtr<MTL ::Function> MSFn;
1828+ if (auto Err = compileStage (Stages::Mesh, MS , " mesh" , MSIR , MSLib, MSFn))
1829+ return Err;
1830+
// The amplification stage is optional; ASFn stays empty when AS is absent.
1831+ MetalIR ASIR ;
1832+ MTLPtr<MTL ::Library> ASLib;
1833+ MTLPtr<MTL ::Function> ASFn;
1834+ if (AS ) {
1835+ if (auto Err = compileStage (Stages::Amplification, *AS , " amplification" ,
1836+ ASIR , ASLib, ASFn))
1837+ return Err;
1838+ }
1839+
// The pixel stage is optional as well; PSFn stays empty when PS is absent.
1840+ MetalIR PSIR ;
1841+ MTLPtr<MTL ::Library> PSLib;
1842+ MTLPtr<MTL ::Function> PSFn;
1843+ if (PS ) {
1844+ if (auto Err =
1845+ compileStage (Stages::Pixel, *PS , " fragment" , PSIR , PSLib, PSFn))
1846+ return Err;
1847+ }
1848+
// The scope_exit guard releases Desc on every path out of this function.
1849+ MTL ::MeshRenderPipelineDescriptor *Desc =
1850+ MTL::MeshRenderPipelineDescriptor::alloc ()->init ();
1851+ auto DescScope = llvm::scope_exit ([&] { Desc->release (); });
1852+
// The amplification shader is bound as the descriptor's object function.
1853+ Desc->setMeshFunction (MSFn.get ());
1854+ if (ASFn)
1855+ Desc->setObjectFunction (ASFn.get ());
1856+ if (PSFn)
1857+ Desc->setFragmentFunction (PSFn.get ());
1858+
// One color attachment per render-target format. The local attachment
// descriptor is released once it has been handed to the pipeline descriptor.
1859+ for (size_t I = 0 ; I < RTFormats.size (); ++I) {
1860+ MTL ::RenderPipelineColorAttachmentDescriptor *RPCA =
1861+ MTL::RenderPipelineColorAttachmentDescriptor::alloc ()->init ();
1862+ RPCA ->setPixelFormat (getMetalPixelFormat (RTFormats[I]));
1863+ Desc->colorAttachments ()->setObject (RPCA , I);
1864+ RPCA ->release ();
1865+ }
1866+
// The depth format is applied when given; the same pixel format is reused for
// the stencil attachment when isStencilFormat() reports true.
1867+ if (DSFormat) {
1868+ const MTL ::PixelFormat DSPixelFormat = getMetalPixelFormat (*DSFormat);
1869+ Desc->setDepthAttachmentPixelFormat (DSPixelFormat);
1870+ if (isStencilFormat (*DSFormat))
1871+ Desc->setStencilAttachmentPixelFormat (DSPixelFormat);
1872+ }
1873+
// Failure is detected through the Error out-parameter rather than by
// checking PSO for null.
1874+ MTL ::RenderPipelineState *PSO = Device->newRenderPipelineState (
1875+ Desc, MTL ::PipelineOptionNone, /* reflection=*/ nullptr , &Error);
1876+ if (Error)
1877+ return toError (Error);
1878+
// Depth-stencil state is fixed: compare function "less", depth writes on.
1879+ MTL ::DepthStencilDescriptor *DSDesc =
1880+ MTL::DepthStencilDescriptor::alloc ()->init ();
1881+ DSDesc->setDepthCompareFunction (MTL ::CompareFunctionLess);
1882+ DSDesc->setDepthWriteEnabled (true );
1883+ MTL ::DepthStencilState *DSState = Device->newDepthStencilState (DSDesc);
1884+ DSDesc->release ();
1885+
1886+ // Pull threads-per-threadgroup from shader reflection.
// MeshTGSize keeps its 1x1x1 default when the reflection copy returns false.
// NOTE(review): MSInfo is declared without an initializer and is released
// even when the copy call returned false; confirm the release function
// tolerates a struct that was never filled in.
1887+ MTL ::Size MeshTGSize (1 , 1 , 1 );
1888+ {
1889+ IRVersionedMSInfo MSInfo;
1890+ if (IRShaderReflectionCopyMeshInfo (MSIR .Reflection .get (),
1891+ IRReflectionVersion_1_0, &MSInfo)) {
1892+ MeshTGSize = MTL::Size (MSInfo.info_1_0 .num_threads [0 ],
1893+ MSInfo.info_1_0 .num_threads [1 ],
1894+ MSInfo.info_1_0 .num_threads [2 ]);
1895+ }
1896+ IRShaderReflectionReleaseMeshInfo (&MSInfo);
1897+ }
1898+
// The object-stage size is only queried when an amplification shader was
// supplied; otherwise it stays 1x1x1.
// NOTE(review): same release-after-failed-copy pattern as MSInfo above.
1899+ MTL ::Size ObjectTGSize (1 , 1 , 1 );
1900+ if (AS ) {
1901+ IRVersionedASInfo ASInfo;
1902+ if (IRShaderReflectionCopyAmplificationInfo (
1903+ ASIR .Reflection .get (), IRReflectionVersion_1_0, &ASInfo)) {
1904+ ObjectTGSize = MTL::Size (ASInfo.info_1_0 .num_threads [0 ],
1905+ ASInfo.info_1_0 .num_threads [1 ],
1906+ ASInfo.info_1_0 .num_threads [2 ]);
1907+ }
1908+ IRShaderReflectionReleaseAmplificationInfo (&ASInfo);
1909+ }
1910+
// Cull mode is fixed to "none" for pipelines created here.
1911+ return std::make_unique<MTLPipelineState>(
1912+ Name, std::move (RootSig), std::move (ArgBuffer), PSO , DSState,
1913+ MTL ::CullModeNone, MeshTGSize, ObjectTGSize);
1914+ }
1915+
17521916 llvm::Error executeProgram (Pipeline &P) override {
17531917 InvocationState IS ;
17541918
@@ -1797,32 +1961,7 @@ class MTLDevice : public offloadtest::Device {
17971961
17981962 if (auto Err = createComputeCommands (P, IS ))
17991963 return Err;
1800- } else {
1801- ShaderContainer VS = {};
1802- ShaderContainer PS = {};
1803- for (auto &Shader : P.Shaders ) {
1804- if (Shader.Stage == Stages::Vertex) {
1805- VS .EntryPoint = Shader.Entry ;
1806- VS .Shader = Shader.Shader .get ();
1807- } else if (Shader.Stage == Stages::Pixel) {
1808- PS .EntryPoint = Shader.Entry ;
1809- PS .Shader = Shader.Shader .get ();
1810- }
1811- }
1812-
1813- llvm::SmallVector<InputLayoutDesc> InputLayout;
1814- for (auto &Attr : P.Bindings .VertexAttributes ) {
1815- auto FormatOrErr = toFormat (Attr.Format , Attr.Channels );
1816- if (!FormatOrErr)
1817- return FormatOrErr.takeError ();
1818-
1819- InputLayoutDesc Desc = {};
1820- Desc.Name = Attr.Name ;
1821- Desc.Fmt = *FormatOrErr;
1822- Desc.OffsetInBytes = Attr.Offset ;
1823- InputLayout.push_back (Desc);
1824- }
1825-
1964+ } else if (P.isRaster ()) {
18261965 auto FormatOrErr = toFormat (P.Bindings .RTargetBufferPtr ->Format ,
18271966 P.Bindings .RTargetBufferPtr ->Channels );
18281967 if (!FormatOrErr)
@@ -1831,12 +1970,63 @@ class MTLDevice : public offloadtest::Device {
18311970 llvm::SmallVector<Format> RTFormats;
18321971 RTFormats.push_back (*FormatOrErr);
18331972
1834- auto PipelineStateOrErr =
1835- createPipelineVsPs (" Graphics Pipeline State" , Bindings, InputLayout,
1836- RTFormats, Format::D32FloatS8Uint, VS , PS );
1837- if (!PipelineStateOrErr)
1838- return PipelineStateOrErr.takeError ();
1839- IS .Pipeline = std::move (*PipelineStateOrErr);
1973+ if (P.isTraditionalRaster ()) {
1974+ ShaderContainer VS = {};
1975+ ShaderContainer PS = {};
1976+ for (auto &Shader : P.Shaders ) {
1977+ if (Shader.Stage == Stages::Vertex) {
1978+ VS .EntryPoint = Shader.Entry ;
1979+ VS .Shader = Shader.Shader .get ();
1980+ } else if (Shader.Stage == Stages::Pixel) {
1981+ PS .EntryPoint = Shader.Entry ;
1982+ PS .Shader = Shader.Shader .get ();
1983+ }
1984+ }
1985+
1986+ // Create the input layout based on the vertex attributes.
1987+ llvm::SmallVector<InputLayoutDesc> InputLayout;
1988+ for (auto &Attr : P.Bindings .VertexAttributes ) {
1989+ auto FormatOrErr = toFormat (Attr.Format , Attr.Channels );
1990+ if (!FormatOrErr)
1991+ return FormatOrErr.takeError ();
1992+
1993+ InputLayoutDesc Desc = {};
1994+ Desc.Name = Attr.Name ;
1995+ Desc.Fmt = *FormatOrErr;
1996+ Desc.OffsetInBytes = Attr.Offset ;
1997+ InputLayout.push_back (Desc);
1998+ }
1999+
2000+ auto PipelineStateOrErr =
2001+ createPipelineVsPs (" Graphics Pipeline State" , Bindings, InputLayout,
2002+ RTFormats, Format::D32FloatS8Uint, VS , PS );
2003+ if (!PipelineStateOrErr)
2004+ return PipelineStateOrErr.takeError ();
2005+ IS .Pipeline = std::move (*PipelineStateOrErr);
2006+ } else if (P.isMeshShaderRaster ()) {
2007+ std::optional<ShaderContainer> AS ;
2008+ ShaderContainer MS = {};
2009+ std::optional<ShaderContainer> PS ;
2010+ for (auto &Shader : P.Shaders ) {
2011+ ShaderContainer SC = {};
2012+ SC .EntryPoint = Shader.Entry ;
2013+ SC .Shader = Shader.Shader .get ();
2014+ if (Shader.Stage == Stages::Amplification)
2015+ AS = std::move (SC );
2016+ else if (Shader.Stage == Stages::Mesh)
2017+ MS = std::move (SC );
2018+ else if (Shader.Stage == Stages::Pixel)
2019+ PS = std::move (SC );
2020+ }
2021+
2022+ auto PipelineStateOrErr =
2023+ createPipelineAsMsPs (" Mesh Shader Pipeline State" , Bindings,
2024+ RTFormats, Format::D32FloatS8Uint, AS , MS , PS );
2025+ if (!PipelineStateOrErr)
2026+ return PipelineStateOrErr.takeError ();
2027+ IS .Pipeline = std::move (*PipelineStateOrErr);
2028+ llvm::outs () << " Mesh Shader Pipeline created.\n " ;
2029+ }
18402030
18412031 ColorAttachmentFormatDesc ColorAttachment = {};
18422032 ColorAttachment.Fmt = *FormatOrErr;
@@ -1878,7 +2068,13 @@ class MTLDevice : public offloadtest::Device {
18782068 virtual ~MTLDevice () {};
18792069
18802070private:
/// Records device capabilities in Caps. Currently a single boolean
/// mesh-shader capability is inserted, set from whether the device reports
/// support for MTL::GPUFamilyMetal3.
1881- void queryCapabilities () {}
2071+ void queryCapabilities () {
2072+ // GPU Family Metal3 (macOS 13+) is where mesh shaders became available.
2073+ const bool MeshShaderSupported =
2074+ Device->supportsFamily (MTL ::GPUFamilyMetal3);
// The same capability name is used as both the map key and the capability's
// own name.
2075+ Caps.insert (std::make_pair (
2076+ " MeshShader" , makeCapability<bool >(" MeshShader" , MeshShaderSupported)));
2077+ }
18822078};
18832079} // namespace
18842080
0 commit comments