@@ -248,6 +248,9 @@ class MTLPipelineState : public offloadtest::PipelineState {
248248 MTL ::DepthStencilState *DepthStencilState = nullptr ;
249249 MTL ::CullMode CullMode = MTL ::CullModeNone;
250250
251+ MTL ::Size MeshThreadsPerThreadgroup{1 , 1 , 1 };
252+ MTL ::Size ObjectThreadsPerThreadgroup{1 , 1 , 1 };
253+
251254 MTLPipelineState (llvm::StringRef Name, IRRootSignaturePtr RootSig,
252255 std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
253256 IRShaderReflectionPtr Reflection,
@@ -260,11 +263,15 @@ class MTLPipelineState : public offloadtest::PipelineState {
260263 std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
261264 MTL ::RenderPipelineState *RenderPipeline,
262265 MTL ::DepthStencilState *DepthStencilState,
263- MTL ::CullMode CullMode)
266+ MTL ::CullMode CullMode,
267+ MTL ::Size MeshThreadsPerThreadgroup = {1 , 1 , 1 },
268+ MTL ::Size ObjectThreadsPerThreadgroup = {1 , 1 , 1 })
264269 : offloadtest::PipelineState(GPUAPI ::Metal), Name(Name),
265270 RootSig(std::move(RootSig)), ArgBuffer(std::move(ArgBuffer)),
266271 RenderPipeline(RenderPipeline), DepthStencilState(DepthStencilState),
267- CullMode(CullMode) {}
272+ CullMode(CullMode),
273+ MeshThreadsPerThreadgroup(MeshThreadsPerThreadgroup),
274+ ObjectThreadsPerThreadgroup(ObjectThreadsPerThreadgroup) {}
268275
269276 ~MTLPipelineState () override {
270277 if (ComputePipeline)
@@ -708,13 +715,30 @@ class MTLRenderEncoder : public offloadtest::RenderEncoder {
708715 llvm::Error dispatchMesh (const offloadtest::PipelineState &PSO ,
709716 uint32_t GroupCountX, uint32_t GroupCountY,
710717 uint32_t GroupCountZ) override {
711- (void )PSO ;
712- (void )GroupCountX;
713- (void )GroupCountY;
714- (void )GroupCountZ;
718+ if (!ViewportSet)
719+ return llvm::createStringError (std::errc::invalid_argument,
720+ " Viewport must be set before drawing." );
721+ if (!ScissorSet)
722+ return llvm::createStringError (std::errc::invalid_argument,
723+ " Scissor must be set before drawing." );
715724
716- return llvm::createStringError (
717- " dispatchMesh is unimplemented in the Metal backend." );
725+ const auto &MTLPSO = llvm::cast<MTLPipelineState>(PSO );
726+ if (!MTLPSO .RenderPipeline )
727+ return llvm::createStringError (
728+ std::errc::invalid_argument,
729+ " PipelineState bound to dispatchMesh() is not a render pipeline." );
730+ RenderEnc->setRenderPipelineState (MTLPSO .RenderPipeline );
731+ if (MTLPSO .DepthStencilState )
732+ RenderEnc->setDepthStencilState (MTLPSO .DepthStencilState );
733+ RenderEnc->setCullMode (MTLPSO .CullMode );
734+ // Match the DX/VK convention (CCW = front) hardcoded in those backends.
735+ RenderEnc->setFrontFacingWinding (MTL ::WindingCounterClockwise);
736+
737+ RenderEnc->drawMeshThreadgroups (
738+ MTL::Size (GroupCountX, GroupCountY, GroupCountZ),
739+ MTLPSO .ObjectThreadsPerThreadgroup , MTLPSO .MeshThreadsPerThreadgroup );
740+
741+ return llvm::Error::success ();
718742 }
719743
720744 void endEncodingImpl () override {
@@ -1413,12 +1437,23 @@ class MTLDevice : public offloadtest::Device {
14131437 Scissor.Height = static_cast <uint32_t >(Height);
14141438 Encoder.setScissor (Scissor);
14151439
1416- if (IS .VB )
1417- Encoder.setVertexBuffer (0 , IS .VB .get (), 0 , P.Bindings .getVertexStride ());
1440+ if (P.isTraditionalRaster ()) {
1441+ if (IS .VB )
1442+ Encoder.setVertexBuffer (0 , IS .VB .get (), 0 ,
1443+ P.Bindings .getVertexStride ());
1444+
1445+ if (auto Err =
1446+ Encoder.drawInstanced (*IS .Pipeline .get (), P.getVertexCount (),
1447+ /* InstanceCount=*/ 1 ))
1448+ return Err;
1449+ } else {
1450+ if (auto Err = Encoder.dispatchMesh (
1451+ *IS .Pipeline .get (), P.DispatchParameters .DispatchGroupCount [0 ],
1452+ P.DispatchParameters .DispatchGroupCount [1 ],
1453+ P.DispatchParameters .DispatchGroupCount [2 ]))
1454+ return Err;
1455+ }
14181456
1419- if (auto Err = Encoder.drawInstanced (*IS .Pipeline .get (), P.getVertexCount (),
1420- /* InstanceCount=*/ 1 ))
1421- return Err;
14221457 Encoder.endEncoding ();
14231458
14241459 // Blit the render target into the readback buffer for CPU access.
@@ -1749,6 +1784,136 @@ class MTLDevice : public offloadtest::Device {
17491784 DSState, MTL ::CullModeNone);
17501785 }
17511786
1787+ llvm::Expected<std::unique_ptr<PipelineState>>
1788+ createPipelineAsMsPs (llvm::StringRef Name, const BindingsDesc &BindingsDesc,
1789+ llvm::ArrayRef<Format> RTFormats,
1790+ std::optional<Format> DSFormat,
1791+ std::optional<ShaderContainer> AS , ShaderContainer MS ,
1792+ std::optional<ShaderContainer> PS ) {
1793+ IRRootSignaturePtr RootSig;
1794+ std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer;
1795+ if (auto Err = createRootSignature (BindingsDesc, /* IsGraphics=*/ true ,
1796+ RootSig, ArgBuffer))
1797+ return Err;
1798+
1799+ NS ::Error *Error = nullptr ;
1800+ auto compileStage = [&](Stages Stage, const ShaderContainer &SC ,
1801+ llvm::StringRef RoleName, MetalIR &OutIR,
1802+ MTLPtr<MTL ::Library> &OutLib,
1803+ MTLPtr<MTL ::Function> &OutFn) -> llvm::Error {
1804+ auto IROrErr =
1805+ convertToMetalIR (Stage, /* IsGraphics=*/ true , RootSig.get (), SC );
1806+ if (!IROrErr)
1807+ return IROrErr.takeError ();
1808+ OutIR = std::move (*IROrErr);
1809+
1810+ dispatch_data_t Data = IRMetalLibGetBytecodeData (OutIR.Binary .get ());
1811+ NS ::Error *Err = nullptr ;
1812+ OutLib = MTLPtr<MTL ::Library>(Device->newLibrary (Data, &Err));
1813+ if (Err)
1814+ return toError (Err);
1815+
1816+ OutFn = MTLPtr<MTL ::Function>(OutLib->newFunction (
1817+ NS::String::string (SC .EntryPoint.c_str(), NS::UTF8StringEncoding)));
1818+ if (!OutFn)
1819+ return llvm::createStringError (
1820+ std::errc::invalid_argument,
1821+ " Failed to find %s entry point '%s' in Metal library." ,
1822+ RoleName.data (), SC .EntryPoint .c_str ());
1823+ return llvm::Error::success ();
1824+ };
1825+
1826+ MetalIR MSIR ;
1827+ MTLPtr<MTL ::Library> MSLib;
1828+ MTLPtr<MTL ::Function> MSFn;
1829+ if (auto Err = compileStage (Stages::Mesh, MS , " mesh" , MSIR , MSLib, MSFn))
1830+ return Err;
1831+
1832+ MetalIR ASIR ;
1833+ MTLPtr<MTL ::Library> ASLib;
1834+ MTLPtr<MTL ::Function> ASFn;
1835+ if (AS ) {
1836+ if (auto Err = compileStage (Stages::Amplification, *AS , " amplification" ,
1837+ ASIR , ASLib, ASFn))
1838+ return Err;
1839+ }
1840+
1841+ MetalIR PSIR ;
1842+ MTLPtr<MTL ::Library> PSLib;
1843+ MTLPtr<MTL ::Function> PSFn;
1844+ if (PS ) {
1845+ if (auto Err =
1846+ compileStage (Stages::Pixel, *PS , " fragment" , PSIR , PSLib, PSFn))
1847+ return Err;
1848+ }
1849+
1850+ MTL ::MeshRenderPipelineDescriptor *Desc =
1851+ MTL::MeshRenderPipelineDescriptor::alloc ()->init ();
1852+ auto DescScope = llvm::scope_exit ([&] { Desc->release (); });
1853+
1854+ Desc->setMeshFunction (MSFn.get ());
1855+ if (ASFn)
1856+ Desc->setObjectFunction (ASFn.get ());
1857+ if (PSFn)
1858+ Desc->setFragmentFunction (PSFn.get ());
1859+
1860+ for (size_t I = 0 ; I < RTFormats.size (); ++I) {
1861+ MTL ::RenderPipelineColorAttachmentDescriptor *RPCA =
1862+ MTL::RenderPipelineColorAttachmentDescriptor::alloc ()->init ();
1863+ RPCA ->setPixelFormat (getMetalPixelFormat (RTFormats[I]));
1864+ Desc->colorAttachments ()->setObject (RPCA , I);
1865+ RPCA ->release ();
1866+ }
1867+
1868+ if (DSFormat) {
1869+ const MTL ::PixelFormat DSPixelFormat = getMetalPixelFormat (*DSFormat);
1870+ Desc->setDepthAttachmentPixelFormat (DSPixelFormat);
1871+ if (isStencilFormat (*DSFormat))
1872+ Desc->setStencilAttachmentPixelFormat (DSPixelFormat);
1873+ }
1874+
1875+ MTL ::RenderPipelineState *PSO = Device->newRenderPipelineState (
1876+ Desc, MTL ::PipelineOptionNone, /* reflection=*/ nullptr , &Error);
1877+ if (Error)
1878+ return toError (Error);
1879+
1880+ MTL ::DepthStencilDescriptor *DSDesc =
1881+ MTL::DepthStencilDescriptor::alloc ()->init ();
1882+ DSDesc->setDepthCompareFunction (MTL ::CompareFunctionLess);
1883+ DSDesc->setDepthWriteEnabled (true );
1884+ MTL ::DepthStencilState *DSState = Device->newDepthStencilState (DSDesc);
1885+ DSDesc->release ();
1886+
1887+ // Pull threads-per-threadgroup from shader reflection.
1888+ MTL ::Size MeshTGSize (1 , 1 , 1 );
1889+ {
1890+ IRVersionedMSInfo MSInfo;
1891+ if (IRShaderReflectionCopyMeshInfo (MSIR .Reflection .get (),
1892+ IRReflectionVersion_1_0, &MSInfo)) {
1893+ MeshTGSize = MTL::Size (MSInfo.info_1_0 .num_threads [0 ],
1894+ MSInfo.info_1_0 .num_threads [1 ],
1895+ MSInfo.info_1_0 .num_threads [2 ]);
1896+ }
1897+ IRShaderReflectionReleaseMeshInfo (&MSInfo);
1898+ }
1899+
1900+ MTL ::Size ObjectTGSize (1 , 1 , 1 );
1901+ if (AS ) {
1902+ IRVersionedASInfo ASInfo;
1903+ if (IRShaderReflectionCopyAmplificationInfo (
1904+ ASIR .Reflection .get (), IRReflectionVersion_1_0, &ASInfo)) {
1905+ ObjectTGSize = MTL::Size (ASInfo.info_1_0 .num_threads [0 ],
1906+ ASInfo.info_1_0 .num_threads [1 ],
1907+ ASInfo.info_1_0 .num_threads [2 ]);
1908+ }
1909+ IRShaderReflectionReleaseAmplificationInfo (&ASInfo);
1910+ }
1911+
1912+ return std::make_unique<MTLPipelineState>(
1913+ Name, std::move (RootSig), std::move (ArgBuffer), PSO , DSState,
1914+ MTL ::CullModeNone, MeshTGSize, ObjectTGSize);
1915+ }
1916+
17521917 llvm::Error executeProgram (Pipeline &P) override {
17531918 InvocationState IS ;
17541919
@@ -1797,32 +1962,7 @@ class MTLDevice : public offloadtest::Device {
17971962
17981963 if (auto Err = createComputeCommands (P, IS ))
17991964 return Err;
1800- } else {
1801- ShaderContainer VS = {};
1802- ShaderContainer PS = {};
1803- for (auto &Shader : P.Shaders ) {
1804- if (Shader.Stage == Stages::Vertex) {
1805- VS .EntryPoint = Shader.Entry ;
1806- VS .Shader = Shader.Shader .get ();
1807- } else if (Shader.Stage == Stages::Pixel) {
1808- PS .EntryPoint = Shader.Entry ;
1809- PS .Shader = Shader.Shader .get ();
1810- }
1811- }
1812-
1813- llvm::SmallVector<InputLayoutDesc> InputLayout;
1814- for (auto &Attr : P.Bindings .VertexAttributes ) {
1815- auto FormatOrErr = toFormat (Attr.Format , Attr.Channels );
1816- if (!FormatOrErr)
1817- return FormatOrErr.takeError ();
1818-
1819- InputLayoutDesc Desc = {};
1820- Desc.Name = Attr.Name ;
1821- Desc.Fmt = *FormatOrErr;
1822- Desc.OffsetInBytes = Attr.Offset ;
1823- InputLayout.push_back (Desc);
1824- }
1825-
1965+ } else if (P.isRaster ()) {
18261966 auto FormatOrErr = toFormat (P.Bindings .RTargetBufferPtr ->Format ,
18271967 P.Bindings .RTargetBufferPtr ->Channels );
18281968 if (!FormatOrErr)
@@ -1831,12 +1971,63 @@ class MTLDevice : public offloadtest::Device {
18311971 llvm::SmallVector<Format> RTFormats;
18321972 RTFormats.push_back (*FormatOrErr);
18331973
1834- auto PipelineStateOrErr =
1835- createPipelineVsPs (" Graphics Pipeline State" , Bindings, InputLayout,
1836- RTFormats, Format::D32FloatS8Uint, VS , PS );
1837- if (!PipelineStateOrErr)
1838- return PipelineStateOrErr.takeError ();
1839- IS .Pipeline = std::move (*PipelineStateOrErr);
1974+ if (P.isTraditionalRaster ()) {
1975+ ShaderContainer VS = {};
1976+ ShaderContainer PS = {};
1977+ for (auto &Shader : P.Shaders ) {
1978+ if (Shader.Stage == Stages::Vertex) {
1979+ VS .EntryPoint = Shader.Entry ;
1980+ VS .Shader = Shader.Shader .get ();
1981+ } else if (Shader.Stage == Stages::Pixel) {
1982+ PS .EntryPoint = Shader.Entry ;
1983+ PS .Shader = Shader.Shader .get ();
1984+ }
1985+ }
1986+
1987+ // Create the input layout based on the vertex attributes.
1988+ llvm::SmallVector<InputLayoutDesc> InputLayout;
1989+ for (auto &Attr : P.Bindings .VertexAttributes ) {
1990+ auto FormatOrErr = toFormat (Attr.Format , Attr.Channels );
1991+ if (!FormatOrErr)
1992+ return FormatOrErr.takeError ();
1993+
1994+ InputLayoutDesc Desc = {};
1995+ Desc.Name = Attr.Name ;
1996+ Desc.Fmt = *FormatOrErr;
1997+ Desc.OffsetInBytes = Attr.Offset ;
1998+ InputLayout.push_back (Desc);
1999+ }
2000+
2001+ auto PipelineStateOrErr =
2002+ createPipelineVsPs (" Graphics Pipeline State" , Bindings, InputLayout,
2003+ RTFormats, Format::D32FloatS8Uint, VS , PS );
2004+ if (!PipelineStateOrErr)
2005+ return PipelineStateOrErr.takeError ();
2006+ IS .Pipeline = std::move (*PipelineStateOrErr);
2007+ } else if (P.isMeshShaderRaster ()) {
2008+ std::optional<ShaderContainer> AS ;
2009+ ShaderContainer MS = {};
2010+ std::optional<ShaderContainer> PS ;
2011+ for (auto &Shader : P.Shaders ) {
2012+ ShaderContainer SC = {};
2013+ SC .EntryPoint = Shader.Entry ;
2014+ SC .Shader = Shader.Shader .get ();
2015+ if (Shader.Stage == Stages::Amplification)
2016+ AS = std::move (SC );
2017+ else if (Shader.Stage == Stages::Mesh)
2018+ MS = std::move (SC );
2019+ else if (Shader.Stage == Stages::Pixel)
2020+ PS = std::move (SC );
2021+ }
2022+
2023+ auto PipelineStateOrErr =
2024+ createPipelineAsMsPs (" Mesh Shader Pipeline State" , Bindings,
2025+ RTFormats, Format::D32FloatS8Uint, AS , MS , PS );
2026+ if (!PipelineStateOrErr)
2027+ return PipelineStateOrErr.takeError ();
2028+ IS .Pipeline = std::move (*PipelineStateOrErr);
2029+ llvm::outs () << " Mesh Shader Pipeline created.\n " ;
2030+ }
18402031
18412032 ColorAttachmentFormatDesc ColorAttachment = {};
18422033 ColorAttachment.Fmt = *FormatOrErr;
@@ -1878,7 +2069,13 @@ class MTLDevice : public offloadtest::Device {
18782069 virtual ~MTLDevice () {};
18792070
18802071private:
1881- void queryCapabilities () {}
2072+ void queryCapabilities () {
2073+ // GPU Family Metal3 (macOS 13+) is where mesh shaders became available.
2074+ const bool MeshShaderSupported =
2075+ Device->supportsFamily (MTL ::GPUFamilyMetal3);
2076+ Caps.insert (std::make_pair (
2077+ " MeshShader" , makeCapability<bool >(" MeshShader" , MeshShaderSupported)));
2078+ }
18822079};
18832080} // namespace
18842081
0 commit comments