@@ -258,6 +258,9 @@ class MTLPipelineState : public offloadtest::PipelineState {
258258 MTL ::DepthStencilState *DepthStencilState = nullptr ;
259259 MTL ::CullMode CullMode = MTL ::CullModeNone;
260260
261+ MTL ::Size MeshThreadsPerThreadgroup{1 , 1 , 1 };
262+ MTL ::Size ObjectThreadsPerThreadgroup{1 , 1 , 1 };
263+
261264 MTLPipelineState (llvm::StringRef Name, IRRootSignaturePtr RootSig,
262265 std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
263266 MTL ::ComputePipelineState *ComputePipeline,
@@ -270,11 +273,15 @@ class MTLPipelineState : public offloadtest::PipelineState {
270273 std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
271274 MTL ::RenderPipelineState *RenderPipeline,
272275 MTL ::DepthStencilState *DepthStencilState,
273- MTL ::CullMode CullMode)
276+ MTL ::CullMode CullMode,
277+ MTL ::Size MeshThreadsPerThreadgroup = {1 , 1 , 1 },
278+ MTL ::Size ObjectThreadsPerThreadgroup = {1 , 1 , 1 })
274279 : offloadtest::PipelineState(GPUAPI ::Metal), Name(Name),
275280 RootSig(std::move(RootSig)), ArgBuffer(std::move(ArgBuffer)),
276281 RenderPipeline(RenderPipeline), DepthStencilState(DepthStencilState),
277- CullMode(CullMode) {}
282+ CullMode(CullMode),
283+ MeshThreadsPerThreadgroup(MeshThreadsPerThreadgroup),
284+ ObjectThreadsPerThreadgroup(ObjectThreadsPerThreadgroup) {}
278285
279286 ~MTLPipelineState () override {
280287 if (ComputePipeline)
@@ -714,13 +721,30 @@ class MTLRenderEncoder : public offloadtest::RenderEncoder {
714721 llvm::Error dispatchMesh (const offloadtest::PipelineState &PSO ,
715722 uint32_t GroupCountX, uint32_t GroupCountY,
716723 uint32_t GroupCountZ) override {
717- (void )PSO ;
718- (void )GroupCountX;
719- (void )GroupCountY;
720- (void )GroupCountZ;
724+ if (!ViewportSet)
725+ return llvm::createStringError (std::errc::invalid_argument,
726+ " Viewport must be set before drawing." );
727+ if (!ScissorSet)
728+ return llvm::createStringError (std::errc::invalid_argument,
729+ " Scissor must be set before drawing." );
721730
722- return llvm::createStringError (
723- " dispatchMesh is unimplemented in the Metal backend." );
731+ const auto &MTLPSO = llvm::cast<MTLPipelineState>(PSO );
732+ if (!MTLPSO .RenderPipeline )
733+ return llvm::createStringError (
734+ std::errc::invalid_argument,
735+ " PipelineState bound to dispatchMesh() is not a render pipeline." );
736+ RenderEnc->setRenderPipelineState (MTLPSO .RenderPipeline );
737+ if (MTLPSO .DepthStencilState )
738+ RenderEnc->setDepthStencilState (MTLPSO .DepthStencilState );
739+ RenderEnc->setCullMode (MTLPSO .CullMode );
740+ // Match the DX/VK convention (CCW = front) hardcoded in those backends.
741+ RenderEnc->setFrontFacingWinding (MTL ::WindingCounterClockwise);
742+
743+ RenderEnc->drawMeshThreadgroups (
744+ MTL::Size (GroupCountX, GroupCountY, GroupCountZ),
745+ MTLPSO .ObjectThreadsPerThreadgroup , MTLPSO .MeshThreadsPerThreadgroup );
746+
747+ return llvm::Error::success ();
724748 }
725749
726750 void endEncodingImpl () override {
@@ -1405,12 +1429,23 @@ class MTLDevice : public offloadtest::Device {
14051429 Scissor.Height = static_cast <uint32_t >(Height);
14061430 Encoder.setScissor (Scissor);
14071431
1408- if (IS .VB )
1409- Encoder.setVertexBuffer (0 , IS .VB .get (), 0 , P.Bindings .getVertexStride ());
1432+ if (P.isTraditionalRaster ()) {
1433+ if (IS .VB )
1434+ Encoder.setVertexBuffer (0 , IS .VB .get (), 0 ,
1435+ P.Bindings .getVertexStride ());
1436+
1437+ if (auto Err =
1438+ Encoder.drawInstanced (*IS .Pipeline .get (), P.getVertexCount (),
1439+ /* InstanceCount=*/ 1 ))
1440+ return Err;
1441+ } else {
1442+ if (auto Err = Encoder.dispatchMesh (
1443+ *IS .Pipeline .get (), P.DispatchParameters .DispatchGroupCount [0 ],
1444+ P.DispatchParameters .DispatchGroupCount [1 ],
1445+ P.DispatchParameters .DispatchGroupCount [2 ]))
1446+ return Err;
1447+ }
14101448
1411- if (auto Err = Encoder.drawInstanced (*IS .Pipeline .get (), P.getVertexCount (),
1412- /* InstanceCount=*/ 1 ))
1413- return Err;
14141449 Encoder.endEncoding ();
14151450
14161451 // Blit the render target into the readback buffer for CPU access.
@@ -1769,6 +1804,136 @@ class MTLDevice : public offloadtest::Device {
17691804 DSState, MTL ::CullModeNone);
17701805 }
17711806
1807+ llvm::Expected<std::unique_ptr<PipelineState>>
1808+ createPipelineAsMsPs (llvm::StringRef Name, const BindingsDesc &BindingsDesc,
1809+ llvm::ArrayRef<Format> RTFormats,
1810+ std::optional<Format> DSFormat,
1811+ std::optional<ShaderContainer> AS , ShaderContainer MS ,
1812+ std::optional<ShaderContainer> PS ) {
1813+ IRRootSignaturePtr RootSig;
1814+ std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer;
1815+ if (auto Err = createRootSignature (BindingsDesc, /* IsGraphics=*/ true ,
1816+ RootSig, ArgBuffer))
1817+ return Err;
1818+
1819+ NS ::Error *Error = nullptr ;
1820+ auto compileStage = [&](Stages Stage, const ShaderContainer &SC ,
1821+ llvm::StringRef RoleName, MetalIR &OutIR,
1822+ MTLPtr<MTL ::Library> &OutLib,
1823+ MTLPtr<MTL ::Function> &OutFn) -> llvm::Error {
1824+ auto IROrErr =
1825+ convertToMetalIR (Stage, /* IsGraphics=*/ true , RootSig.get (), SC );
1826+ if (!IROrErr)
1827+ return IROrErr.takeError ();
1828+ OutIR = std::move (*IROrErr);
1829+
1830+ dispatch_data_t Data = IRMetalLibGetBytecodeData (OutIR.Binary .get ());
1831+ NS ::Error *Err = nullptr ;
1832+ OutLib = MTLPtr<MTL ::Library>(Device->newLibrary (Data, &Err));
1833+ if (Err)
1834+ return toError (Err);
1835+
1836+ OutFn = MTLPtr<MTL ::Function>(OutLib->newFunction (
1837+ NS::String::string (SC .EntryPoint.c_str(), NS::UTF8StringEncoding)));
1838+ if (!OutFn)
1839+ return llvm::createStringError (
1840+ std::errc::invalid_argument,
1841+ " Failed to find %s entry point '%s' in Metal library." ,
1842+ RoleName.data (), SC .EntryPoint .c_str ());
1843+ return llvm::Error::success ();
1844+ };
1845+
1846+ MetalIR MSIR ;
1847+ MTLPtr<MTL ::Library> MSLib;
1848+ MTLPtr<MTL ::Function> MSFn;
1849+ if (auto Err = compileStage (Stages::Mesh, MS , " mesh" , MSIR , MSLib, MSFn))
1850+ return Err;
1851+
1852+ MetalIR ASIR ;
1853+ MTLPtr<MTL ::Library> ASLib;
1854+ MTLPtr<MTL ::Function> ASFn;
1855+ if (AS ) {
1856+ if (auto Err = compileStage (Stages::Amplification, *AS , " amplification" ,
1857+ ASIR , ASLib, ASFn))
1858+ return Err;
1859+ }
1860+
1861+ MetalIR PSIR ;
1862+ MTLPtr<MTL ::Library> PSLib;
1863+ MTLPtr<MTL ::Function> PSFn;
1864+ if (PS ) {
1865+ if (auto Err =
1866+ compileStage (Stages::Pixel, *PS , " fragment" , PSIR , PSLib, PSFn))
1867+ return Err;
1868+ }
1869+
1870+ MTL ::MeshRenderPipelineDescriptor *Desc =
1871+ MTL::MeshRenderPipelineDescriptor::alloc ()->init ();
1872+ auto DescScope = llvm::scope_exit ([&] { Desc->release (); });
1873+
1874+ Desc->setMeshFunction (MSFn.get ());
1875+ if (ASFn)
1876+ Desc->setObjectFunction (ASFn.get ());
1877+ if (PSFn)
1878+ Desc->setFragmentFunction (PSFn.get ());
1879+
1880+ for (size_t I = 0 ; I < RTFormats.size (); ++I) {
1881+ MTL ::RenderPipelineColorAttachmentDescriptor *RPCA =
1882+ MTL::RenderPipelineColorAttachmentDescriptor::alloc ()->init ();
1883+ RPCA ->setPixelFormat (getMetalPixelFormat (RTFormats[I]));
1884+ Desc->colorAttachments ()->setObject (RPCA , I);
1885+ RPCA ->release ();
1886+ }
1887+
1888+ if (DSFormat) {
1889+ const MTL ::PixelFormat DSPixelFormat = getMetalPixelFormat (*DSFormat);
1890+ Desc->setDepthAttachmentPixelFormat (DSPixelFormat);
1891+ if (isStencilFormat (*DSFormat))
1892+ Desc->setStencilAttachmentPixelFormat (DSPixelFormat);
1893+ }
1894+
1895+ MTL ::RenderPipelineState *PSO = Device->newRenderPipelineState (
1896+ Desc, MTL ::PipelineOptionNone, /* reflection=*/ nullptr , &Error);
1897+ if (Error)
1898+ return toError (Error);
1899+
1900+ MTL ::DepthStencilDescriptor *DSDesc =
1901+ MTL::DepthStencilDescriptor::alloc ()->init ();
1902+ DSDesc->setDepthCompareFunction (MTL ::CompareFunctionLess);
1903+ DSDesc->setDepthWriteEnabled (true );
1904+ MTL ::DepthStencilState *DSState = Device->newDepthStencilState (DSDesc);
1905+ DSDesc->release ();
1906+
1907+ // Pull threads-per-threadgroup from shader reflection.
1908+ MTL ::Size MeshTGSize (1 , 1 , 1 );
1909+ {
1910+ IRVersionedMSInfo MSInfo;
1911+ if (IRShaderReflectionCopyMeshInfo (MSIR .Reflection .get (),
1912+ IRReflectionVersion_1_0, &MSInfo)) {
1913+ MeshTGSize = MTL::Size (MSInfo.info_1_0 .num_threads [0 ],
1914+ MSInfo.info_1_0 .num_threads [1 ],
1915+ MSInfo.info_1_0 .num_threads [2 ]);
1916+ }
1917+ IRShaderReflectionReleaseMeshInfo (&MSInfo);
1918+ }
1919+
1920+ MTL ::Size ObjectTGSize (1 , 1 , 1 );
1921+ if (AS ) {
1922+ IRVersionedASInfo ASInfo;
1923+ if (IRShaderReflectionCopyAmplificationInfo (
1924+ ASIR .Reflection .get (), IRReflectionVersion_1_0, &ASInfo)) {
1925+ ObjectTGSize = MTL::Size (ASInfo.info_1_0 .num_threads [0 ],
1926+ ASInfo.info_1_0 .num_threads [1 ],
1927+ ASInfo.info_1_0 .num_threads [2 ]);
1928+ }
1929+ IRShaderReflectionReleaseAmplificationInfo (&ASInfo);
1930+ }
1931+
1932+ return std::make_unique<MTLPipelineState>(
1933+ Name, std::move (RootSig), std::move (ArgBuffer), PSO , DSState,
1934+ MTL ::CullModeNone, MeshTGSize, ObjectTGSize);
1935+ }
1936+
17721937 llvm::Error executeProgram (Pipeline &P) override {
17731938 InvocationState IS ;
17741939
@@ -1817,40 +1982,68 @@ class MTLDevice : public offloadtest::Device {
18171982
18181983 if (auto Err = createComputeCommands (P, IS ))
18191984 return Err;
1820- } else {
1821- TraditionalRasterPipelineCreateDesc PipelineDesc = {};
1822- PipelineDesc.Topology = P.Bindings .Topology ;
1823- PipelineDesc.DSFormat = Format::D32FloatS8Uint;
1824- for (auto &Shader : P.Shaders ) {
1825- ShaderContainer SC = {};
1826- SC .EntryPoint = Shader.Entry ;
1827- SC .Shader = Shader.Shader .get ();
1828- PipelineDesc.setShader (Shader.Stage , std::move (SC ));
1829- }
1830-
1831- for (auto &Attr : P.Bindings .VertexAttributes ) {
1832- auto FormatOrErr = toFormat (Attr.Format , Attr.Channels );
1833- if (!FormatOrErr)
1834- return FormatOrErr.takeError ();
1835-
1836- InputLayoutDesc Layout = {};
1837- Layout.Name = Attr.Name ;
1838- Layout.Fmt = *FormatOrErr;
1839- Layout.OffsetInBytes = Attr.Offset ;
1840- PipelineDesc.InputLayout .push_back (Layout);
1841- }
1842-
1985+ } else if (P.isRaster ()) {
18431986 auto FormatOrErr = toFormat (P.Bindings .RTargetBufferPtr ->Format ,
18441987 P.Bindings .RTargetBufferPtr ->Channels );
1988+
1989+ llvm::SmallVector<Format> RTFormats;
18451990 if (!FormatOrErr)
18461991 return FormatOrErr.takeError ();
1847- PipelineDesc.RTFormats .push_back (*FormatOrErr);
1992+ RTFormats.push_back (*FormatOrErr);
1993+
1994+ if (P.isTraditionalRaster ()) {
1995+ TraditionalRasterPipelineCreateDesc PipelineDesc = {};
1996+ PipelineDesc.Topology = P.Bindings .Topology ;
1997+ PipelineDesc.DSFormat = Format::D32FloatS8Uint;
1998+ PipelineDesc.RTFormats = RTFormats;
1999+ for (auto &Shader : P.Shaders ) {
2000+ ShaderContainer SC = {};
2001+ SC .EntryPoint = Shader.Entry ;
2002+ SC .Shader = Shader.Shader .get ();
2003+ PipelineDesc.setShader (Shader.Stage , std::move (SC ));
2004+ }
18482005
1849- auto PipelineStateOrErr = createTraditionalRasterPipeline (
1850- " Graphics Pipeline State" , Bindings, PipelineDesc);
1851- if (!PipelineStateOrErr)
1852- return PipelineStateOrErr.takeError ();
1853- IS .Pipeline = std::move (*PipelineStateOrErr);
2006+ for (auto &Attr : P.Bindings .VertexAttributes ) {
2007+ auto FormatOrErr = toFormat (Attr.Format , Attr.Channels );
2008+ if (!FormatOrErr)
2009+ return FormatOrErr.takeError ();
2010+
2011+ InputLayoutDesc Layout = {};
2012+ Layout.Name = Attr.Name ;
2013+ Layout.Fmt = *FormatOrErr;
2014+ Layout.OffsetInBytes = Attr.Offset ;
2015+ PipelineDesc.InputLayout .push_back (Layout);
2016+ }
2017+
2018+ auto PipelineStateOrErr = createTraditionalRasterPipeline (
2019+ " Graphics Pipeline State" , Bindings, PipelineDesc);
2020+ if (!PipelineStateOrErr)
2021+ return PipelineStateOrErr.takeError ();
2022+ IS .Pipeline = std::move (*PipelineStateOrErr);
2023+ } else if (P.isMeshShaderRaster ()) {
2024+ std::optional<ShaderContainer> AS ;
2025+ ShaderContainer MS = {};
2026+ std::optional<ShaderContainer> PS ;
2027+ for (auto &Shader : P.Shaders ) {
2028+ ShaderContainer SC = {};
2029+ SC .EntryPoint = Shader.Entry ;
2030+ SC .Shader = Shader.Shader .get ();
2031+ if (Shader.Stage == Stages::Amplification)
2032+ AS = std::move (SC );
2033+ else if (Shader.Stage == Stages::Mesh)
2034+ MS = std::move (SC );
2035+ else if (Shader.Stage == Stages::Pixel)
2036+ PS = std::move (SC );
2037+ }
2038+
2039+ auto PipelineStateOrErr =
2040+ createPipelineAsMsPs (" Mesh Shader Pipeline State" , Bindings,
2041+ RTFormats, Format::D32FloatS8Uint, AS , MS , PS );
2042+ if (!PipelineStateOrErr)
2043+ return PipelineStateOrErr.takeError ();
2044+ IS .Pipeline = std::move (*PipelineStateOrErr);
2045+ llvm::outs () << " Mesh Shader Pipeline created.\n " ;
2046+ }
18542047
18552048 ColorAttachmentFormatDesc ColorAttachment = {};
18562049 ColorAttachment.Fmt = *FormatOrErr;
@@ -1892,7 +2085,13 @@ class MTLDevice : public offloadtest::Device {
18922085 virtual ~MTLDevice () {};
18932086
18942087private:
1895- void queryCapabilities () {}
2088+ void queryCapabilities () {
2089+ // GPU Family Metal3 (macOS 13+) is where mesh shaders became available.
2090+ const bool MeshShaderSupported =
2091+ Device->supportsFamily (MTL ::GPUFamilyMetal3);
2092+ Caps.insert (std::make_pair (
2093+ " MeshShader" , makeCapability<bool >(" MeshShader" , MeshShaderSupported)));
2094+ }
18962095};
18972096} // namespace
18982097
0 commit comments