@@ -254,6 +254,9 @@ class MTLPipelineState : public offloadtest::PipelineState {
254254 MTL ::DepthStencilState *DepthStencilState = nullptr ;
255255 MTL ::CullMode CullMode = MTL ::CullModeNone;
256256
257+ MTL ::Size MeshThreadsPerThreadgroup{1 , 1 , 1 };
258+ MTL ::Size ObjectThreadsPerThreadgroup{1 , 1 , 1 };
259+
257260 MTLPipelineState (llvm::StringRef Name, IRRootSignaturePtr RootSig,
258261 std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
259262 MTL ::ComputePipelineState *ComputePipeline,
@@ -266,11 +269,15 @@ class MTLPipelineState : public offloadtest::PipelineState {
266269 std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
267270 MTL ::RenderPipelineState *RenderPipeline,
268271 MTL ::DepthStencilState *DepthStencilState,
269- MTL ::CullMode CullMode)
272+ MTL ::CullMode CullMode,
273+ MTL ::Size MeshThreadsPerThreadgroup = {1 , 1 , 1 },
274+ MTL ::Size ObjectThreadsPerThreadgroup = {1 , 1 , 1 })
270275 : offloadtest::PipelineState(GPUAPI ::Metal), Name(Name),
271276 RootSig(std::move(RootSig)), ArgBuffer(std::move(ArgBuffer)),
272277 RenderPipeline(RenderPipeline), DepthStencilState(DepthStencilState),
273- CullMode(CullMode) {}
278+ CullMode(CullMode),
279+ MeshThreadsPerThreadgroup(MeshThreadsPerThreadgroup),
280+ ObjectThreadsPerThreadgroup(ObjectThreadsPerThreadgroup) {}
274281
275282 ~MTLPipelineState () override {
276283 if (ComputePipeline)
@@ -710,13 +717,30 @@ class MTLRenderEncoder : public offloadtest::RenderEncoder {
710717 llvm::Error dispatchMesh (const offloadtest::PipelineState &PSO ,
711718 uint32_t GroupCountX, uint32_t GroupCountY,
712719 uint32_t GroupCountZ) override {
713- (void )PSO ;
714- (void )GroupCountX;
715- (void )GroupCountY;
716- (void )GroupCountZ;
720+ if (!ViewportSet)
721+ return llvm::createStringError (std::errc::invalid_argument,
722+ " Viewport must be set before drawing." );
723+ if (!ScissorSet)
724+ return llvm::createStringError (std::errc::invalid_argument,
725+ " Scissor must be set before drawing." );
717726
718- return llvm::createStringError (
719- " dispatchMesh is unimplemented in the Metal backend." );
727+ const auto &MTLPSO = llvm::cast<MTLPipelineState>(PSO );
728+ if (!MTLPSO .RenderPipeline )
729+ return llvm::createStringError (
730+ std::errc::invalid_argument,
731+ " PipelineState bound to dispatchMesh() is not a render pipeline." );
732+ RenderEnc->setRenderPipelineState (MTLPSO .RenderPipeline );
733+ if (MTLPSO .DepthStencilState )
734+ RenderEnc->setDepthStencilState (MTLPSO .DepthStencilState );
735+ RenderEnc->setCullMode (MTLPSO .CullMode );
736+ // Match the DX/VK convention (CCW = front) hardcoded in those backends.
737+ RenderEnc->setFrontFacingWinding (MTL ::WindingCounterClockwise);
738+
739+ RenderEnc->drawMeshThreadgroups (
740+ MTL::Size (GroupCountX, GroupCountY, GroupCountZ),
741+ MTLPSO .ObjectThreadsPerThreadgroup , MTLPSO .MeshThreadsPerThreadgroup );
742+
743+ return llvm::Error::success ();
720744 }
721745
722746 void endEncodingImpl () override {
@@ -1401,12 +1425,23 @@ class MTLDevice : public offloadtest::Device {
14011425 Scissor.Height = static_cast <uint32_t >(Height);
14021426 Encoder.setScissor (Scissor);
14031427
1404- if (IS .VB )
1405- Encoder.setVertexBuffer (0 , IS .VB .get (), 0 , P.Bindings .getVertexStride ());
1428+ if (P.isTraditionalRaster ()) {
1429+ if (IS .VB )
1430+ Encoder.setVertexBuffer (0 , IS .VB .get (), 0 ,
1431+ P.Bindings .getVertexStride ());
1432+
1433+ if (auto Err =
1434+ Encoder.drawInstanced (*IS .Pipeline .get (), P.getVertexCount (),
1435+ /* InstanceCount=*/ 1 ))
1436+ return Err;
1437+ } else {
1438+ if (auto Err = Encoder.dispatchMesh (
1439+ *IS .Pipeline .get (), P.DispatchParameters .DispatchGroupCount [0 ],
1440+ P.DispatchParameters .DispatchGroupCount [1 ],
1441+ P.DispatchParameters .DispatchGroupCount [2 ]))
1442+ return Err;
1443+ }
14061444
1407- if (auto Err = Encoder.drawInstanced (*IS .Pipeline .get (), P.getVertexCount (),
1408- /* InstanceCount=*/ 1 ))
1409- return Err;
14101445 Encoder.endEncoding ();
14111446
14121447 // Blit the render target into the readback buffer for CPU access.
@@ -1760,6 +1795,136 @@ class MTLDevice : public offloadtest::Device {
17601795 DSState, MTL ::CullModeNone);
17611796 }
17621797
1798+ llvm::Expected<std::unique_ptr<PipelineState>>
1799+ createPipelineAsMsPs (llvm::StringRef Name, const BindingsDesc &BindingsDesc,
1800+ llvm::ArrayRef<Format> RTFormats,
1801+ std::optional<Format> DSFormat,
1802+ std::optional<ShaderContainer> AS , ShaderContainer MS ,
1803+ std::optional<ShaderContainer> PS ) {
1804+ IRRootSignaturePtr RootSig;
1805+ std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer;
1806+ if (auto Err = createRootSignature (BindingsDesc, /* IsGraphics=*/ true ,
1807+ RootSig, ArgBuffer))
1808+ return Err;
1809+
1810+ NS ::Error *Error = nullptr ;
1811+ auto compileStage = [&](Stages Stage, const ShaderContainer &SC ,
1812+ llvm::StringRef RoleName, MetalIR &OutIR,
1813+ MTLPtr<MTL ::Library> &OutLib,
1814+ MTLPtr<MTL ::Function> &OutFn) -> llvm::Error {
1815+ auto IROrErr =
1816+ convertToMetalIR (Stage, /* IsGraphics=*/ true , RootSig.get (), SC );
1817+ if (!IROrErr)
1818+ return IROrErr.takeError ();
1819+ OutIR = std::move (*IROrErr);
1820+
1821+ dispatch_data_t Data = IRMetalLibGetBytecodeData (OutIR.Binary .get ());
1822+ NS ::Error *Err = nullptr ;
1823+ OutLib = MTLPtr<MTL ::Library>(Device->newLibrary (Data, &Err));
1824+ if (Err)
1825+ return toError (Err);
1826+
1827+ OutFn = MTLPtr<MTL ::Function>(OutLib->newFunction (
1828+ NS::String::string (SC .EntryPoint.c_str(), NS::UTF8StringEncoding)));
1829+ if (!OutFn)
1830+ return llvm::createStringError (
1831+ std::errc::invalid_argument,
1832+ " Failed to find %s entry point '%s' in Metal library." ,
1833+ RoleName.data (), SC .EntryPoint .c_str ());
1834+ return llvm::Error::success ();
1835+ };
1836+
1837+ MetalIR MSIR ;
1838+ MTLPtr<MTL ::Library> MSLib;
1839+ MTLPtr<MTL ::Function> MSFn;
1840+ if (auto Err = compileStage (Stages::Mesh, MS , " mesh" , MSIR , MSLib, MSFn))
1841+ return Err;
1842+
1843+ MetalIR ASIR ;
1844+ MTLPtr<MTL ::Library> ASLib;
1845+ MTLPtr<MTL ::Function> ASFn;
1846+ if (AS ) {
1847+ if (auto Err = compileStage (Stages::Amplification, *AS , " amplification" ,
1848+ ASIR , ASLib, ASFn))
1849+ return Err;
1850+ }
1851+
1852+ MetalIR PSIR ;
1853+ MTLPtr<MTL ::Library> PSLib;
1854+ MTLPtr<MTL ::Function> PSFn;
1855+ if (PS ) {
1856+ if (auto Err =
1857+ compileStage (Stages::Pixel, *PS , " fragment" , PSIR , PSLib, PSFn))
1858+ return Err;
1859+ }
1860+
1861+ MTL ::MeshRenderPipelineDescriptor *Desc =
1862+ MTL::MeshRenderPipelineDescriptor::alloc ()->init ();
1863+ auto DescScope = llvm::scope_exit ([&] { Desc->release (); });
1864+
1865+ Desc->setMeshFunction (MSFn.get ());
1866+ if (ASFn)
1867+ Desc->setObjectFunction (ASFn.get ());
1868+ if (PSFn)
1869+ Desc->setFragmentFunction (PSFn.get ());
1870+
1871+ for (size_t I = 0 ; I < RTFormats.size (); ++I) {
1872+ MTL ::RenderPipelineColorAttachmentDescriptor *RPCA =
1873+ MTL::RenderPipelineColorAttachmentDescriptor::alloc ()->init ();
1874+ RPCA ->setPixelFormat (getMetalPixelFormat (RTFormats[I]));
1875+ Desc->colorAttachments ()->setObject (RPCA , I);
1876+ RPCA ->release ();
1877+ }
1878+
1879+ if (DSFormat) {
1880+ const MTL ::PixelFormat DSPixelFormat = getMetalPixelFormat (*DSFormat);
1881+ Desc->setDepthAttachmentPixelFormat (DSPixelFormat);
1882+ if (isStencilFormat (*DSFormat))
1883+ Desc->setStencilAttachmentPixelFormat (DSPixelFormat);
1884+ }
1885+
1886+ MTL ::RenderPipelineState *PSO = Device->newRenderPipelineState (
1887+ Desc, MTL ::PipelineOptionNone, /* reflection=*/ nullptr , &Error);
1888+ if (Error)
1889+ return toError (Error);
1890+
1891+ MTL ::DepthStencilDescriptor *DSDesc =
1892+ MTL::DepthStencilDescriptor::alloc ()->init ();
1893+ DSDesc->setDepthCompareFunction (MTL ::CompareFunctionLess);
1894+ DSDesc->setDepthWriteEnabled (true );
1895+ MTL ::DepthStencilState *DSState = Device->newDepthStencilState (DSDesc);
1896+ DSDesc->release ();
1897+
1898+ // Pull threads-per-threadgroup from shader reflection.
1899+ MTL ::Size MeshTGSize (1 , 1 , 1 );
1900+ {
1901+ IRVersionedMSInfo MSInfo;
1902+ if (IRShaderReflectionCopyMeshInfo (MSIR .Reflection .get (),
1903+ IRReflectionVersion_1_0, &MSInfo)) {
1904+ MeshTGSize = MTL::Size (MSInfo.info_1_0 .num_threads [0 ],
1905+ MSInfo.info_1_0 .num_threads [1 ],
1906+ MSInfo.info_1_0 .num_threads [2 ]);
1907+ }
1908+ IRShaderReflectionReleaseMeshInfo (&MSInfo);
1909+ }
1910+
1911+ MTL ::Size ObjectTGSize (1 , 1 , 1 );
1912+ if (AS ) {
1913+ IRVersionedASInfo ASInfo;
1914+ if (IRShaderReflectionCopyAmplificationInfo (
1915+ ASIR .Reflection .get (), IRReflectionVersion_1_0, &ASInfo)) {
1916+ ObjectTGSize = MTL::Size (ASInfo.info_1_0 .num_threads [0 ],
1917+ ASInfo.info_1_0 .num_threads [1 ],
1918+ ASInfo.info_1_0 .num_threads [2 ]);
1919+ }
1920+ IRShaderReflectionReleaseAmplificationInfo (&ASInfo);
1921+ }
1922+
1923+ return std::make_unique<MTLPipelineState>(
1924+ Name, std::move (RootSig), std::move (ArgBuffer), PSO , DSState,
1925+ MTL ::CullModeNone, MeshTGSize, ObjectTGSize);
1926+ }
1927+
17631928 llvm::Error executeProgram (Pipeline &P) override {
17641929 InvocationState IS ;
17651930
@@ -1808,40 +1973,68 @@ class MTLDevice : public offloadtest::Device {
18081973
18091974 if (auto Err = createComputeCommands (P, IS ))
18101975 return Err;
1811- } else {
1812- TraditionalRasterPipelineCreateDesc PipelineDesc = {};
1813- PipelineDesc.Topology = P.Bindings .Topology ;
1814- PipelineDesc.DSFormat = Format::D32FloatS8Uint;
1815- for (auto &Shader : P.Shaders ) {
1816- ShaderContainer SC = {};
1817- SC .EntryPoint = Shader.Entry ;
1818- SC .Shader = Shader.Shader .get ();
1819- PipelineDesc.setShader (Shader.Stage , std::move (SC ));
1820- }
1821-
1822- for (auto &Attr : P.Bindings .VertexAttributes ) {
1823- auto FormatOrErr = toFormat (Attr.Format , Attr.Channels );
1824- if (!FormatOrErr)
1825- return FormatOrErr.takeError ();
1826-
1827- InputLayoutDesc Layout = {};
1828- Layout.Name = Attr.Name ;
1829- Layout.Fmt = *FormatOrErr;
1830- Layout.OffsetInBytes = Attr.Offset ;
1831- PipelineDesc.InputLayout .push_back (Layout);
1832- }
1833-
1976+ } else if (P.isRaster ()) {
18341977 auto FormatOrErr = toFormat (P.Bindings .RTargetBufferPtr ->Format ,
18351978 P.Bindings .RTargetBufferPtr ->Channels );
18361979 if (!FormatOrErr)
18371980 return FormatOrErr.takeError ();
18381981 PipelineDesc.RTFormats .push_back (*FormatOrErr);
18391982
1840- auto PipelineStateOrErr = createTraditionalRasterPipeline (
1841- " Graphics Pipeline State" , Bindings, PipelineDesc);
1842- if (!PipelineStateOrErr)
1843- return PipelineStateOrErr.takeError ();
1844- IS .Pipeline = std::move (*PipelineStateOrErr);
1983+ llvm::SmallVector<Format> RTFormats;
1984+ RTFormats.push_back (*FormatOrErr);
1985+
1986+ if (P.isTraditionalRaster ()) {
1987+ TraditionalRasterPipelineCreateDesc PipelineDesc = {};
1988+ PipelineDesc.Topology = P.Bindings .Topology ;
1989+ PipelineDesc.DSFormat = Format::D32FloatS8Uint;
1990+ for (auto &Shader : P.Shaders ) {
1991+ ShaderContainer SC = {};
1992+ SC .EntryPoint = Shader.Entry ;
1993+ SC .Shader = Shader.Shader .get ();
1994+ PipelineDesc.setShader (Shader.Stage , std::move (SC ));
1995+ }
1996+
1997+ for (auto &Attr : P.Bindings .VertexAttributes ) {
1998+ auto FormatOrErr = toFormat (Attr.Format , Attr.Channels );
1999+ if (!FormatOrErr)
2000+ return FormatOrErr.takeError ();
2001+
2002+ InputLayoutDesc Layout = {};
2003+ Layout.Name = Attr.Name ;
2004+ Layout.Fmt = *FormatOrErr;
2005+ Layout.OffsetInBytes = Attr.Offset ;
2006+ PipelineDesc.InputLayout .push_back (Layout);
2007+ }
2008+
2009+ auto PipelineStateOrErr = createTraditionalRasterPipeline (
2010+ " Graphics Pipeline State" , Bindings, PipelineDesc);
2011+ if (!PipelineStateOrErr)
2012+ return PipelineStateOrErr.takeError ();
2013+ IS .Pipeline = std::move (*PipelineStateOrErr);
2014+ } else if (P.isMeshShaderRaster ()) {
2015+ std::optional<ShaderContainer> AS ;
2016+ ShaderContainer MS = {};
2017+ std::optional<ShaderContainer> PS ;
2018+ for (auto &Shader : P.Shaders ) {
2019+ ShaderContainer SC = {};
2020+ SC .EntryPoint = Shader.Entry ;
2021+ SC .Shader = Shader.Shader .get ();
2022+ if (Shader.Stage == Stages::Amplification)
2023+ AS = std::move (SC );
2024+ else if (Shader.Stage == Stages::Mesh)
2025+ MS = std::move (SC );
2026+ else if (Shader.Stage == Stages::Pixel)
2027+ PS = std::move (SC );
2028+ }
2029+
2030+ auto PipelineStateOrErr =
2031+ createPipelineAsMsPs (" Mesh Shader Pipeline State" , Bindings,
2032+ RTFormats, Format::D32FloatS8Uint, AS , MS , PS );
2033+ if (!PipelineStateOrErr)
2034+ return PipelineStateOrErr.takeError ();
2035+ IS .Pipeline = std::move (*PipelineStateOrErr);
2036+ llvm::outs () << " Mesh Shader Pipeline created.\n " ;
2037+ }
18452038
18462039 ColorAttachmentFormatDesc ColorAttachment = {};
18472040 ColorAttachment.Fmt = *FormatOrErr;
@@ -1883,7 +2076,13 @@ class MTLDevice : public offloadtest::Device {
18832076 virtual ~MTLDevice () {};
18842077
18852078private:
1886- void queryCapabilities () {}
2079+ void queryCapabilities () {
2080+ // GPU Family Metal3 (macOS 13+) is where mesh shaders became available.
2081+ const bool MeshShaderSupported =
2082+ Device->supportsFamily (MTL ::GPUFamilyMetal3);
2083+ Caps.insert (std::make_pair (
2084+ " MeshShader" , makeCapability<bool >(" MeshShader" , MeshShaderSupported)));
2085+ }
18872086};
18882087} // namespace
18892088
0 commit comments