diff --git a/include/API/Device.h b/include/API/Device.h index f0b85829b..7f12cbb5d 100644 --- a/include/API/Device.h +++ b/include/API/Device.h @@ -124,6 +124,36 @@ struct TraditionalRasterPipelineCreateDesc { } }; +struct MeshShaderRasterPipelineCreateDesc { + llvm::SmallVector RTFormats; + std::optional DSFormat; + PrimitiveTopology Topology; + + ShaderContainer MS; + std::optional AS; + std::optional PS; + + void setShader(Stages Stage, ShaderContainer &&SC) { + switch (Stage) { + case Stages::Amplification: + AS = std::move(SC); + break; + case Stages::Mesh: + MS = std::move(SC); + break; + case Stages::Pixel: + PS = std::move(SC); + break; + case Stages::Vertex: + case Stages::Hull: + case Stages::Domain: + case Stages::Geometry: + case Stages::Compute: + llvm_unreachable("Not a mesh raster pipeline stage."); + } + } +}; + class PipelineState { public: GPUAPI API; @@ -208,6 +238,11 @@ class Device { llvm::StringRef Name, const BindingsDesc &BindingsDesc, const TraditionalRasterPipelineCreateDesc &Desc) = 0; + virtual llvm::Expected> + createMeshShaderRasterPipeline( + llvm::StringRef Name, const BindingsDesc &BindingsDesc, + const MeshShaderRasterPipelineCreateDesc &Desc) = 0; + virtual llvm::Expected> createFence(llvm::StringRef Name) = 0; diff --git a/lib/API/DX/Device.cpp b/lib/API/DX/Device.cpp index fc1600c6f..e76408596 100644 --- a/lib/API/DX/Device.cpp +++ b/lib/API/DX/Device.cpp @@ -1226,21 +1226,18 @@ class DXDevice : public offloadtest::Device { getDXPrimitiveTopology(Desc.Topology, Desc.PatchControlPoints)); } - llvm::Expected> - createPipelineAsMsPs(llvm::StringRef Name, const BindingsDesc &BndDesc, - llvm::ArrayRef RTFormats, - std::optional DSFormat, - std::optional AS, ShaderContainer MS, - std::optional PS) { - assert(RTFormats.size() <= 8); + llvm::Expected> createMeshShaderRasterPipeline( + llvm::StringRef Name, const BindingsDesc &BindingsDesc, + const MeshShaderRasterPipelineCreateDesc &Desc) override { + assert(Desc.RTFormats.size() <= 8); ComPtr RootSig; - if (auto Err = createRootSignature(Name, BndDesc, MS, + if (auto Err = createRootSignature(Name, BindingsDesc, Desc.MS, /*IsGraphics=*/true, RootSig)) return Err; - const D3D12_SHADER_BYTECODE MSBytecode = {MS.Shader->getBuffer().data(), - MS.Shader->getBuffer().size()}; + const D3D12_SHADER_BYTECODE MSBytecode = { + Desc.MS.Shader->getBuffer().data(), Desc.MS.Shader->getBuffer().size()}; if (MSBytecode.BytecodeLength == 0) return llvm::createStringError( std::errc::invalid_argument, @@ -1248,26 +1245,26 @@ class DXDevice : public offloadtest::Device { // The amplification (task) shader is optional. D3D12_SHADER_BYTECODE ASBytecode = {}; - if (AS) { - assert((*AS).Shader->getBufferSize() > 0 && + if (Desc.AS) { + assert((*Desc.AS).Shader->getBufferSize() > 0 && "The passed task/amplification shader was empty."); - ASBytecode = {(*AS).Shader->getBuffer().data(), - (*AS).Shader->getBuffer().size()}; + ASBytecode = {(*Desc.AS).Shader->getBuffer().data(), + (*Desc.AS).Shader->getBuffer().size()}; } // The pixel shader is optional D3D12_SHADER_BYTECODE PSBytecode = {}; - if (PS) { - assert((*PS).Shader->getBufferSize() > 0 && + if (Desc.PS) { + assert((*Desc.PS).Shader->getBufferSize() > 0 && "The passed pixel shader was empty."); - PSBytecode = {(*PS).Shader->getBuffer().data(), - (*PS).Shader->getBuffer().size()}; + PSBytecode = {(*Desc.PS).Shader->getBuffer().data(), + (*Desc.PS).Shader->getBuffer().size()}; } D3D12_RT_FORMAT_ARRAY RTArray = {}; - RTArray.NumRenderTargets = static_cast(RTFormats.size()); - for (size_t I = 0; I < RTFormats.size(); ++I) - RTArray.RTFormats[I] = getDXGIFormat(RTFormats[I]); + RTArray.NumRenderTargets = static_cast(Desc.RTFormats.size()); + for (size_t I = 0; I < Desc.RTFormats.size(); ++I) + RTArray.RTFormats[I] = getDXGIFormat(Desc.RTFormats[I]); CD3DX12_DEPTH_STENCIL_DESC1 DepthStencil(D3D12_DEFAULT); DepthStencil.DepthEnable = true; @@ -1287,10 +1284,10 @@ class DXDevice : public offloadtest::Device { Stream.BlendState = CD3DX12_BLEND_DESC(D3D12_DEFAULT); Stream.DepthStencilState = DepthStencil; Stream.SampleMask = UINT_MAX; - Stream.PrimitiveTopologyType = D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE; + Stream.PrimitiveTopologyType = getDXPrimitiveTopologyType(Desc.Topology); Stream.RTVFormats = RTArray; - if (DSFormat) - Stream.DSVFormat = getDXGIFormat(*DSFormat); + if (Desc.DSFormat) + Stream.DSVFormat = getDXGIFormat(*Desc.DSFormat); Stream.SampleDesc = SampleDesc; const D3D12_PIPELINE_STATE_STREAM_DESC StreamDesc = {sizeof(Stream), @@ -2564,38 +2561,24 @@ class DXDevice : public offloadtest::Device { llvm::outs() << "Traditional Raster Pipeline created.\n"; } else if (P.isMeshShaderRaster()) { - - std::optional AS = {}; - ShaderContainer MS = {}; - std::optional PS = {}; + MeshShaderRasterPipelineCreateDesc PipelineDesc = {}; + PipelineDesc.Topology = P.Bindings.Topology; + PipelineDesc.DSFormat = Format::D32FloatS8Uint; for (auto &Shader : P.Shaders) { - if (Shader.Stage == Stages::Amplification) { - ShaderContainer Container; - Container.EntryPoint = Shader.Entry; - Container.Shader = Shader.Shader.get(); - AS = Container; - } else if (Shader.Stage == Stages::Mesh) { - MS.EntryPoint = Shader.Entry; - MS.Shader = Shader.Shader.get(); - } else if (Shader.Stage == Stages::Pixel) { - ShaderContainer Container; - Container.EntryPoint = Shader.Entry; - Container.Shader = Shader.Shader.get(); - PS = Container; - } + ShaderContainer SC = {}; + SC.EntryPoint = Shader.Entry; + SC.Shader = Shader.Shader.get(); + PipelineDesc.setShader(Shader.Stage, std::move(SC)); } auto FormatOrErr = toFormat(P.Bindings.RTargetBufferPtr->Format, P.Bindings.RTargetBufferPtr->Channels); if (!FormatOrErr) return FormatOrErr.takeError(); + PipelineDesc.RTFormats.push_back(*FormatOrErr); - llvm::SmallVector RTFormats; - RTFormats.push_back(*FormatOrErr); - - auto PipelineStateOrErr = - createPipelineAsMsPs("Mesh Shader Pipeline State", BndDesc, - RTFormats, Format::D32FloatS8Uint, AS, MS, PS); + auto PipelineStateOrErr = createMeshShaderRasterPipeline( + "Mesh Shader Pipeline State", BndDesc, PipelineDesc); if (!PipelineStateOrErr) return PipelineStateOrErr.takeError(); diff --git a/lib/API/MTL/MTLDevice.cpp b/lib/API/MTL/MTLDevice.cpp index 1303e31eb..e28b6e1b5 100644 --- a/lib/API/MTL/MTLDevice.cpp +++ b/lib/API/MTL/MTLDevice.cpp @@ -1804,12 +1804,9 @@ class MTLDevice : public offloadtest::Device { DSState, MTL::CullModeNone); } - llvm::Expected> - createPipelineAsMsPs(llvm::StringRef Name, const BindingsDesc &BindingsDesc, - llvm::ArrayRef RTFormats, - std::optional DSFormat, - std::optional AS, ShaderContainer MS, - std::optional PS) { + llvm::Expected> createMeshShaderRasterPipeline( + llvm::StringRef Name, const BindingsDesc &BindingsDesc, + const MeshShaderRasterPipelineCreateDesc &Desc) override { IRRootSignaturePtr RootSig; std::unique_ptr ArgBuffer; if (auto Err = createRootSignature(BindingsDesc, /*IsGraphics=*/true, @@ -1846,54 +1843,56 @@ class MTLDevice : public offloadtest::Device { MetalIR MSIR; MTLPtr MSLib; MTLPtr MSFn; - if (auto Err = compileStage(Stages::Mesh, MS, "mesh", MSIR, MSLib, MSFn)) + if (auto Err = + compileStage(Stages::Mesh, Desc.MS, "mesh", MSIR, MSLib, MSFn)) return Err; MetalIR ASIR; MTLPtr ASLib; MTLPtr ASFn; - if (AS) { - if (auto Err = compileStage(Stages::Amplification, *AS, "amplification", - ASIR, ASLib, ASFn)) + if (Desc.AS) { + if (auto Err = compileStage(Stages::Amplification, *Desc.AS, + "amplification", ASIR, ASLib, ASFn)) return Err; } MetalIR PSIR; MTLPtr PSLib; MTLPtr PSFn; - if (PS) { - if (auto Err = - compileStage(Stages::Pixel, *PS, "fragment", PSIR, PSLib, PSFn)) + if (Desc.PS) { + if (auto Err = compileStage(Stages::Pixel, *Desc.PS, "fragment", PSIR, + PSLib, PSFn)) return Err; } - MTL::MeshRenderPipelineDescriptor *Desc = + MTL::MeshRenderPipelineDescriptor *MSPDesc = MTL::MeshRenderPipelineDescriptor::alloc()->init(); - auto DescScope = llvm::scope_exit([&] { Desc->release(); }); + auto DescScope = llvm::scope_exit([&] { MSPDesc->release(); }); - Desc->setMeshFunction(MSFn.get()); + MSPDesc->setMeshFunction(MSFn.get()); if (ASFn) - Desc->setObjectFunction(ASFn.get()); + MSPDesc->setObjectFunction(ASFn.get()); if (PSFn) - Desc->setFragmentFunction(PSFn.get()); + MSPDesc->setFragmentFunction(PSFn.get()); - for (size_t I = 0; I < RTFormats.size(); ++I) { + for (size_t I = 0; I < Desc.RTFormats.size(); ++I) { MTL::RenderPipelineColorAttachmentDescriptor *RPCA = MTL::RenderPipelineColorAttachmentDescriptor::alloc()->init(); - RPCA->setPixelFormat(getMetalPixelFormat(RTFormats[I])); - Desc->colorAttachments()->setObject(RPCA, I); + RPCA->setPixelFormat(getMetalPixelFormat(Desc.RTFormats[I])); + MSPDesc->colorAttachments()->setObject(RPCA, I); RPCA->release(); } - if (DSFormat) { - const MTL::PixelFormat DSPixelFormat = getMetalPixelFormat(*DSFormat); - Desc->setDepthAttachmentPixelFormat(DSPixelFormat); - if (isStencilFormat(*DSFormat)) - Desc->setStencilAttachmentPixelFormat(DSPixelFormat); + if (Desc.DSFormat) { + const MTL::PixelFormat DSPixelFormat = + getMetalPixelFormat(*Desc.DSFormat); + MSPDesc->setDepthAttachmentPixelFormat(DSPixelFormat); + if (isStencilFormat(*Desc.DSFormat)) + MSPDesc->setStencilAttachmentPixelFormat(DSPixelFormat); } MTL::RenderPipelineState *PSO = Device->newRenderPipelineState( - Desc, MTL::PipelineOptionNone, /*reflection=*/nullptr, &Error); + MSPDesc, MTL::PipelineOptionNone, /*reflection=*/nullptr, &Error); if (Error) return toError(Error); @@ -1918,7 +1917,7 @@ class MTLDevice : public offloadtest::Device { } MTL::Size ObjectTGSize(1, 1, 1); - if (AS) { + if (Desc.AS) { IRVersionedASInfo ASInfo; if (IRShaderReflectionCopyAmplificationInfo( ASIR.Reflection.get(), IRReflectionVersion_1_0, &ASInfo)) { @@ -2021,24 +2020,19 @@ class MTLDevice : public offloadtest::Device { return PipelineStateOrErr.takeError(); IS.Pipeline = std::move(*PipelineStateOrErr); } else if (P.isMeshShaderRaster()) { - std::optional AS; - ShaderContainer MS = {}; - std::optional PS; + MeshShaderRasterPipelineCreateDesc PipelineDesc = {}; + PipelineDesc.Topology = P.Bindings.Topology; + PipelineDesc.DSFormat = Format::D32FloatS8Uint; + PipelineDesc.RTFormats = RTFormats; for (auto &Shader : P.Shaders) { ShaderContainer SC = {}; SC.EntryPoint = Shader.Entry; SC.Shader = Shader.Shader.get(); - if (Shader.Stage == Stages::Amplification) - AS = std::move(SC); - else if (Shader.Stage == Stages::Mesh) - MS = std::move(SC); - else if (Shader.Stage == Stages::Pixel) - PS = std::move(SC); + PipelineDesc.setShader(Shader.Stage, std::move(SC)); } - auto PipelineStateOrErr = - createPipelineAsMsPs("Mesh Shader Pipeline State", Bindings, - RTFormats, Format::D32FloatS8Uint, AS, MS, PS); + auto PipelineStateOrErr = createMeshShaderRasterPipeline( + "Mesh Shader Pipeline State", Bindings, PipelineDesc); if (!PipelineStateOrErr) return PipelineStateOrErr.takeError(); IS.Pipeline = std::move(*PipelineStateOrErr); diff --git a/lib/API/VK/Device.cpp b/lib/API/VK/Device.cpp index fcc1d7c08..9364bd62e 100644 --- a/lib/API/VK/Device.cpp +++ b/lib/API/VK/Device.cpp @@ -1821,13 +1821,10 @@ class VulkanDevice : public offloadtest::Device { Name, Device, Pipeline, PipelineLayout, std::move(SetLayouts)); } - llvm::Expected> - createPipelineAsMsPs(llvm::StringRef Name, const BindingsDesc &BindingsDesc, - llvm::ArrayRef RTFormats, - std::optional DSFormat, - std::optional AS, ShaderContainer MS, - std::optional PS) /*override*/ { - assert(RTFormats.size() <= 8); + llvm::Expected> createMeshShaderRasterPipeline( + llvm::StringRef Name, const BindingsDesc &BindingsDesc, + const MeshShaderRasterPipelineCreateDesc &Desc) override { + assert(Desc.RTFormats.size() <= 8); VkShaderStageFlags GraphicsFlags = VK_SHADER_STAGE_MESH_BIT_EXT; llvm::SmallVector ShaderStages; @@ -1841,12 +1838,12 @@ class VulkanDevice : public offloadtest::Device { llvm::SmallVector MSSpecData; VkSpecializationInfo MSSpecInfo = {}; { - if (auto Err = parseSpecializationConstants(MS.SpecializationConstants, - MSSpecEntries, MSSpecData, - MSSpecInfo)) + if (auto Err = parseSpecializationConstants( + Desc.MS.SpecializationConstants, MSSpecEntries, MSSpecData, + MSSpecInfo)) return Err; - auto MSModOrErr = createShaderModule(MS.Shader, "mesh"); + auto MSModOrErr = createShaderModule(Desc.MS.Shader, "mesh"); if (!MSModOrErr) return MSModOrErr.takeError(); @@ -1854,22 +1851,22 @@ class VulkanDevice : public offloadtest::Device { ShaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; ShaderStage.stage = VK_SHADER_STAGE_MESH_BIT_EXT; ShaderStage.module = *MSModOrErr; - ShaderStage.pName = MS.EntryPoint.c_str(); + ShaderStage.pName = Desc.MS.EntryPoint.c_str(); ShaderStage.pSpecializationInfo = - MS.SpecializationConstants.empty() ? nullptr : &MSSpecInfo; + Desc.MS.SpecializationConstants.empty() ? nullptr : &MSSpecInfo; ShaderStages.push_back(ShaderStage); } llvm::SmallVector ASSpecEntries; llvm::SmallVector ASSpecData; VkSpecializationInfo ASSpecInfo = {}; - if (AS) { - if (auto Err = parseSpecializationConstants((*AS).SpecializationConstants, - ASSpecEntries, ASSpecData, - ASSpecInfo)) + if (Desc.AS) { + if (auto Err = parseSpecializationConstants( + (*Desc.AS).SpecializationConstants, ASSpecEntries, ASSpecData, + ASSpecInfo)) return Err; - auto ASModOrErr = createShaderModule((*AS).Shader, "task"); + auto ASModOrErr = createShaderModule((*Desc.AS).Shader, "task"); if (!ASModOrErr) return ASModOrErr.takeError(); @@ -1879,22 +1876,22 @@ class VulkanDevice : public offloadtest::Device { ShaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; ShaderStage.stage = VK_SHADER_STAGE_TASK_BIT_EXT; ShaderStage.module = *ASModOrErr; - ShaderStage.pName = (*AS).EntryPoint.c_str(); + ShaderStage.pName = (*Desc.AS).EntryPoint.c_str(); ShaderStage.pSpecializationInfo = - (*AS).SpecializationConstants.empty() ? nullptr : &ASSpecInfo; + (*Desc.AS).SpecializationConstants.empty() ? nullptr : &ASSpecInfo; ShaderStages.push_back(ShaderStage); } llvm::SmallVector PSSpecEntries; llvm::SmallVector PSSpecData; VkSpecializationInfo PSSpecInfo = {}; - if (PS) { - if (auto Err = parseSpecializationConstants((*PS).SpecializationConstants, - PSSpecEntries, PSSpecData, - PSSpecInfo)) + if (Desc.PS) { + if (auto Err = parseSpecializationConstants( + (*Desc.PS).SpecializationConstants, PSSpecEntries, PSSpecData, + PSSpecInfo)) return Err; - auto PSModOrErr = createShaderModule((*PS).Shader, "pixel"); + auto PSModOrErr = createShaderModule((*Desc.PS).Shader, "pixel"); if (!PSModOrErr) return PSModOrErr.takeError(); @@ -1904,23 +1901,23 @@ class VulkanDevice : public offloadtest::Device { ShaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; ShaderStage.stage = VK_SHADER_STAGE_FRAGMENT_BIT; ShaderStage.module = *PSModOrErr; - ShaderStage.pName = (*PS).EntryPoint.c_str(); + ShaderStage.pName = (*Desc.PS).EntryPoint.c_str(); ShaderStage.pSpecializationInfo = - (*PS).SpecializationConstants.empty() ? nullptr : &PSSpecInfo; + (*Desc.PS).SpecializationConstants.empty() ? nullptr : &PSSpecInfo; ShaderStages.push_back(ShaderStage); } // Build a RenderPassDesc from the PSO's RT/DS formats. RenderPassDesc PassDesc; - PassDesc.ColorAttachments.reserve(RTFormats.size()); - for (const Format F : RTFormats) { + PassDesc.ColorAttachments.reserve(Desc.RTFormats.size()); + for (const Format F : Desc.RTFormats) { ColorAttachmentFormatDesc CA = {}; CA.Fmt = F; PassDesc.ColorAttachments.push_back(CA); } - if (DSFormat) { + if (Desc.DSFormat) { DepthStencilAttachmentFormatDesc DS = {}; - DS.Fmt = *DSFormat; + DS.Fmt = *Desc.DSFormat; PassDesc.DepthStencil = DS; } @@ -1978,7 +1975,7 @@ class VulkanDevice : public offloadtest::Device { DepthStencilCI.front = DepthStencilCI.back; llvm::SmallVector BlendAttachments( - RTFormats.size()); + Desc.RTFormats.size()); for (auto &BA : BlendAttachments) BA.colorWriteMask = 0xf; VkPipelineColorBlendStateCreateInfo BlendCI = {}; @@ -3603,37 +3600,24 @@ class VulkanDevice : public offloadtest::Device { State.Pipeline = std::move(*PipelineStateOrErr); llvm::outs() << "Graphics Pipeline created.\n"; } else if (P.isMeshShaderRaster()) { - std::optional AS = {}; - ShaderContainer MS = {}; - std::optional PS = {}; + MeshShaderRasterPipelineCreateDesc PipelineDesc = {}; + PipelineDesc.Topology = P.Bindings.Topology; + PipelineDesc.DSFormat = Format::D32FloatS8Uint; for (auto &Shader : P.Shaders) { - if (Shader.Stage == Stages::Amplification) { - ShaderContainer Container; - Container.EntryPoint = Shader.Entry; - Container.Shader = Shader.Shader.get(); - AS = Container; - } else if (Shader.Stage == Stages::Mesh) { - MS.EntryPoint = Shader.Entry; - MS.Shader = Shader.Shader.get(); - } else if (Shader.Stage == Stages::Pixel) { - ShaderContainer Container; - Container.EntryPoint = Shader.Entry; - Container.Shader = Shader.Shader.get(); - PS = Container; - } + ShaderContainer SC = {}; + SC.EntryPoint = Shader.Entry; + SC.Shader = Shader.Shader.get(); + PipelineDesc.setShader(Shader.Stage, std::move(SC)); } auto FormatOrErr = toFormat(P.Bindings.RTargetBufferPtr->Format, P.Bindings.RTargetBufferPtr->Channels); if (!FormatOrErr) return FormatOrErr.takeError(); + PipelineDesc.RTFormats.push_back(*FormatOrErr); - llvm::SmallVector RTFormats; - RTFormats.push_back(*FormatOrErr); - - auto PipelineStateOrErr = - createPipelineAsMsPs("Mesh Shader Pipeline State", BindingsDesc, - RTFormats, Format::D32FloatS8Uint, AS, MS, PS); + auto PipelineStateOrErr = createMeshShaderRasterPipeline( + "Mesh Shader Pipeline State", BindingsDesc, PipelineDesc); if (!PipelineStateOrErr) return PipelineStateOrErr.takeError(); State.Pipeline = std::move(*PipelineStateOrErr);