Skip to content

Commit a4748cc

Browse files
Add the mesh shader pipeline creation function to the Device base class. (#1246)
The mesh shader pipeline creation function was not part of the Device base class. This PR adds it to the base class, and also introduces a create description similar to the traditional raster pipeline creation function. Additionally support for specifying the topology has also been added. This is only needed in DX12 since both Vulkan and Metal just extract the topology type from the shader byte code. Technically DX12 only cares about the topology type (triangles vs lines, instead triangle list vs triangle strip), but I think it is overkill to introduce a new enum for topology types to the .yaml file if we can just use a topology type instead.
1 parent 9bc0142 commit a4748cc

4 files changed

Lines changed: 139 additions & 143 deletions

File tree

include/API/Device.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,36 @@ struct TraditionalRasterPipelineCreateDesc {
125125
}
126126
};
127127

128+
struct MeshShaderRasterPipelineCreateDesc {
129+
llvm::SmallVector<Format> RTFormats;
130+
std::optional<Format> DSFormat;
131+
PrimitiveTopology Topology;
132+
133+
ShaderContainer MS;
134+
std::optional<ShaderContainer> AS;
135+
std::optional<ShaderContainer> PS;
136+
137+
void setShader(Stages Stage, ShaderContainer &&SC) {
138+
switch (Stage) {
139+
case Stages::Amplification:
140+
AS = std::move(SC);
141+
break;
142+
case Stages::Mesh:
143+
MS = std::move(SC);
144+
break;
145+
case Stages::Pixel:
146+
PS = std::move(SC);
147+
break;
148+
case Stages::Vertex:
149+
case Stages::Hull:
150+
case Stages::Domain:
151+
case Stages::Geometry:
152+
case Stages::Compute:
153+
llvm_unreachable("Not a mesh raster pipeline stage.");
154+
}
155+
}
156+
};
157+
128158
class PipelineState {
129159
public:
130160
GPUAPI API;
@@ -211,6 +241,11 @@ class Device {
211241
llvm::StringRef Name, const BindingsDesc &BindingsDesc,
212242
const TraditionalRasterPipelineCreateDesc &Desc) = 0;
213243

244+
virtual llvm::Expected<std::unique_ptr<PipelineState>>
245+
createMeshShaderRasterPipeline(
246+
llvm::StringRef Name, const BindingsDesc &BindingsDesc,
247+
const MeshShaderRasterPipelineCreateDesc &Desc) = 0;
248+
214249
virtual llvm::Expected<std::unique_ptr<Fence>>
215250
createFence(llvm::StringRef Name) = 0;
216251

lib/API/DX/Device.cpp

Lines changed: 31 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,48 +1271,45 @@ class DXDevice : public offloadtest::Device {
12711271
getDXPrimitiveTopology(Desc.Topology, Desc.PatchControlPoints));
12721272
}
12731273

1274-
llvm::Expected<std::unique_ptr<PipelineState>>
1275-
createPipelineAsMsPs(llvm::StringRef Name, const BindingsDesc &BndDesc,
1276-
llvm::ArrayRef<Format> RTFormats,
1277-
std::optional<Format> DSFormat,
1278-
std::optional<ShaderContainer> AS, ShaderContainer MS,
1279-
std::optional<ShaderContainer> PS) {
1280-
assert(RTFormats.size() <= 8);
1274+
llvm::Expected<std::unique_ptr<PipelineState>> createMeshShaderRasterPipeline(
1275+
llvm::StringRef Name, const BindingsDesc &BindingsDesc,
1276+
const MeshShaderRasterPipelineCreateDesc &Desc) override {
1277+
assert(Desc.RTFormats.size() <= 8);
12811278

12821279
ComPtr<ID3D12RootSignature> RootSig;
1283-
if (auto Err = createRootSignature(Name, BndDesc, MS,
1280+
if (auto Err = createRootSignature(Name, BindingsDesc, Desc.MS,
12841281
/*IsGraphics=*/true, RootSig))
12851282
return Err;
12861283

1287-
const D3D12_SHADER_BYTECODE MSBytecode = {MS.Shader->getBuffer().data(),
1288-
MS.Shader->getBuffer().size()};
1284+
const D3D12_SHADER_BYTECODE MSBytecode = {
1285+
Desc.MS.Shader->getBuffer().data(), Desc.MS.Shader->getBuffer().size()};
12891286
if (MSBytecode.BytecodeLength == 0)
12901287
return llvm::createStringError(
12911288
std::errc::invalid_argument,
12921289
"Mesh shader pipeline requires a mesh shader.");
12931290

12941291
// The amplification (task) shader is optional.
12951292
D3D12_SHADER_BYTECODE ASBytecode = {};
1296-
if (AS) {
1297-
assert((*AS).Shader->getBufferSize() > 0 &&
1293+
if (Desc.AS) {
1294+
assert((*Desc.AS).Shader->getBufferSize() > 0 &&
12981295
"The passed task/amplification shader was empty.");
1299-
ASBytecode = {(*AS).Shader->getBuffer().data(),
1300-
(*AS).Shader->getBuffer().size()};
1296+
ASBytecode = {(*Desc.AS).Shader->getBuffer().data(),
1297+
(*Desc.AS).Shader->getBuffer().size()};
13011298
}
13021299

13031300
// The pixel shader is optional
13041301
D3D12_SHADER_BYTECODE PSBytecode = {};
1305-
if (PS) {
1306-
assert((*PS).Shader->getBufferSize() > 0 &&
1302+
if (Desc.PS) {
1303+
assert((*Desc.PS).Shader->getBufferSize() > 0 &&
13071304
"The passed pixel shader was empty.");
1308-
PSBytecode = {(*PS).Shader->getBuffer().data(),
1309-
(*PS).Shader->getBuffer().size()};
1305+
PSBytecode = {(*Desc.PS).Shader->getBuffer().data(),
1306+
(*Desc.PS).Shader->getBuffer().size()};
13101307
}
13111308

13121309
D3D12_RT_FORMAT_ARRAY RTArray = {};
1313-
RTArray.NumRenderTargets = static_cast<UINT>(RTFormats.size());
1314-
for (size_t I = 0; I < RTFormats.size(); ++I)
1315-
RTArray.RTFormats[I] = getDXGIFormat(RTFormats[I]);
1310+
RTArray.NumRenderTargets = static_cast<UINT>(Desc.RTFormats.size());
1311+
for (size_t I = 0; I < Desc.RTFormats.size(); ++I)
1312+
RTArray.RTFormats[I] = getDXGIFormat(Desc.RTFormats[I]);
13161313

13171314
CD3DX12_DEPTH_STENCIL_DESC1 DepthStencil(D3D12_DEFAULT);
13181315
DepthStencil.DepthEnable = true;
@@ -1332,10 +1329,10 @@ class DXDevice : public offloadtest::Device {
13321329
Stream.BlendState = CD3DX12_BLEND_DESC(D3D12_DEFAULT);
13331330
Stream.DepthStencilState = DepthStencil;
13341331
Stream.SampleMask = UINT_MAX;
1335-
Stream.PrimitiveTopologyType = D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE;
1332+
Stream.PrimitiveTopologyType = getDXPrimitiveTopologyType(Desc.Topology);
13361333
Stream.RTVFormats = RTArray;
1337-
if (DSFormat)
1338-
Stream.DSVFormat = getDXGIFormat(*DSFormat);
1334+
if (Desc.DSFormat)
1335+
Stream.DSVFormat = getDXGIFormat(*Desc.DSFormat);
13391336
Stream.SampleDesc = SampleDesc;
13401337

13411338
const D3D12_PIPELINE_STATE_STREAM_DESC StreamDesc = {sizeof(Stream),
@@ -2763,38 +2760,24 @@ class DXDevice : public offloadtest::Device {
27632760
llvm::outs() << "Traditional Raster Pipeline created.\n";
27642761

27652762
} else if (P.isMeshShaderRaster()) {
2766-
2767-
std::optional<ShaderContainer> AS = {};
2768-
ShaderContainer MS = {};
2769-
std::optional<ShaderContainer> PS = {};
2763+
MeshShaderRasterPipelineCreateDesc PipelineDesc = {};
2764+
PipelineDesc.Topology = P.Bindings.Topology;
2765+
PipelineDesc.DSFormat = Format::D32FloatS8Uint;
27702766
for (auto &Shader : P.Shaders) {
2771-
if (Shader.Stage == Stages::Amplification) {
2772-
ShaderContainer Container;
2773-
Container.EntryPoint = Shader.Entry;
2774-
Container.Shader = Shader.Shader.get();
2775-
AS = Container;
2776-
} else if (Shader.Stage == Stages::Mesh) {
2777-
MS.EntryPoint = Shader.Entry;
2778-
MS.Shader = Shader.Shader.get();
2779-
} else if (Shader.Stage == Stages::Pixel) {
2780-
ShaderContainer Container;
2781-
Container.EntryPoint = Shader.Entry;
2782-
Container.Shader = Shader.Shader.get();
2783-
PS = Container;
2784-
}
2767+
ShaderContainer SC = {};
2768+
SC.EntryPoint = Shader.Entry;
2769+
SC.Shader = Shader.Shader.get();
2770+
PipelineDesc.setShader(Shader.Stage, std::move(SC));
27852771
}
27862772

27872773
auto FormatOrErr = toFormat(P.Bindings.RTargetBufferPtr->Format,
27882774
P.Bindings.RTargetBufferPtr->Channels);
27892775
if (!FormatOrErr)
27902776
return FormatOrErr.takeError();
2777+
PipelineDesc.RTFormats.push_back(*FormatOrErr);
27912778

2792-
llvm::SmallVector<Format> RTFormats;
2793-
RTFormats.push_back(*FormatOrErr);
2794-
2795-
auto PipelineStateOrErr =
2796-
createPipelineAsMsPs("Mesh Shader Pipeline State", BndDesc,
2797-
RTFormats, Format::D32FloatS8Uint, AS, MS, PS);
2779+
auto PipelineStateOrErr = createMeshShaderRasterPipeline(
2780+
"Mesh Shader Pipeline State", BndDesc, PipelineDesc);
27982781

27992782
if (!PipelineStateOrErr)
28002783
return PipelineStateOrErr.takeError();

lib/API/MTL/MTLDevice.cpp

Lines changed: 34 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1824,12 +1824,9 @@ class MTLDevice : public offloadtest::Device {
18241824
DSState, MTL::CullModeNone);
18251825
}
18261826

1827-
llvm::Expected<std::unique_ptr<PipelineState>>
1828-
createPipelineAsMsPs(llvm::StringRef Name, const BindingsDesc &BindingsDesc,
1829-
llvm::ArrayRef<Format> RTFormats,
1830-
std::optional<Format> DSFormat,
1831-
std::optional<ShaderContainer> AS, ShaderContainer MS,
1832-
std::optional<ShaderContainer> PS) {
1827+
llvm::Expected<std::unique_ptr<PipelineState>> createMeshShaderRasterPipeline(
1828+
llvm::StringRef Name, const BindingsDesc &BindingsDesc,
1829+
const MeshShaderRasterPipelineCreateDesc &Desc) override {
18331830
IRRootSignaturePtr RootSig;
18341831
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer;
18351832
if (auto Err = createRootSignature(BindingsDesc, /*IsGraphics=*/true,
@@ -1866,54 +1863,56 @@ class MTLDevice : public offloadtest::Device {
18661863
MetalIR MSIR;
18671864
MTLPtr<MTL::Library> MSLib;
18681865
MTLPtr<MTL::Function> MSFn;
1869-
if (auto Err = compileStage(Stages::Mesh, MS, "mesh", MSIR, MSLib, MSFn))
1866+
if (auto Err =
1867+
compileStage(Stages::Mesh, Desc.MS, "mesh", MSIR, MSLib, MSFn))
18701868
return Err;
18711869

18721870
MetalIR ASIR;
18731871
MTLPtr<MTL::Library> ASLib;
18741872
MTLPtr<MTL::Function> ASFn;
1875-
if (AS) {
1876-
if (auto Err = compileStage(Stages::Amplification, *AS, "amplification",
1877-
ASIR, ASLib, ASFn))
1873+
if (Desc.AS) {
1874+
if (auto Err = compileStage(Stages::Amplification, *Desc.AS,
1875+
"amplification", ASIR, ASLib, ASFn))
18781876
return Err;
18791877
}
18801878

18811879
MetalIR PSIR;
18821880
MTLPtr<MTL::Library> PSLib;
18831881
MTLPtr<MTL::Function> PSFn;
1884-
if (PS) {
1885-
if (auto Err =
1886-
compileStage(Stages::Pixel, *PS, "fragment", PSIR, PSLib, PSFn))
1882+
if (Desc.PS) {
1883+
if (auto Err = compileStage(Stages::Pixel, *Desc.PS, "fragment", PSIR,
1884+
PSLib, PSFn))
18871885
return Err;
18881886
}
18891887

1890-
MTL::MeshRenderPipelineDescriptor *Desc =
1888+
MTL::MeshRenderPipelineDescriptor *MSPDesc =
18911889
MTL::MeshRenderPipelineDescriptor::alloc()->init();
1892-
auto DescScope = llvm::scope_exit([&] { Desc->release(); });
1890+
auto DescScope = llvm::scope_exit([&] { MSPDesc->release(); });
18931891

1894-
Desc->setMeshFunction(MSFn.get());
1892+
MSPDesc->setMeshFunction(MSFn.get());
18951893
if (ASFn)
1896-
Desc->setObjectFunction(ASFn.get());
1894+
MSPDesc->setObjectFunction(ASFn.get());
18971895
if (PSFn)
1898-
Desc->setFragmentFunction(PSFn.get());
1896+
MSPDesc->setFragmentFunction(PSFn.get());
18991897

1900-
for (size_t I = 0; I < RTFormats.size(); ++I) {
1898+
for (size_t I = 0; I < Desc.RTFormats.size(); ++I) {
19011899
MTL::RenderPipelineColorAttachmentDescriptor *RPCA =
19021900
MTL::RenderPipelineColorAttachmentDescriptor::alloc()->init();
1903-
RPCA->setPixelFormat(getMetalPixelFormat(RTFormats[I]));
1904-
Desc->colorAttachments()->setObject(RPCA, I);
1901+
RPCA->setPixelFormat(getMetalPixelFormat(Desc.RTFormats[I]));
1902+
MSPDesc->colorAttachments()->setObject(RPCA, I);
19051903
RPCA->release();
19061904
}
19071905

1908-
if (DSFormat) {
1909-
const MTL::PixelFormat DSPixelFormat = getMetalPixelFormat(*DSFormat);
1910-
Desc->setDepthAttachmentPixelFormat(DSPixelFormat);
1911-
if (isStencilFormat(*DSFormat))
1912-
Desc->setStencilAttachmentPixelFormat(DSPixelFormat);
1906+
if (Desc.DSFormat) {
1907+
const MTL::PixelFormat DSPixelFormat =
1908+
getMetalPixelFormat(*Desc.DSFormat);
1909+
MSPDesc->setDepthAttachmentPixelFormat(DSPixelFormat);
1910+
if (isStencilFormat(*Desc.DSFormat))
1911+
MSPDesc->setStencilAttachmentPixelFormat(DSPixelFormat);
19131912
}
19141913

19151914
MTL::RenderPipelineState *PSO = Device->newRenderPipelineState(
1916-
Desc, MTL::PipelineOptionNone, /*reflection=*/nullptr, &Error);
1915+
MSPDesc, MTL::PipelineOptionNone, /*reflection=*/nullptr, &Error);
19171916
if (Error)
19181917
return toError(Error);
19191918

@@ -1938,7 +1937,7 @@ class MTLDevice : public offloadtest::Device {
19381937
}
19391938

19401939
MTL::Size ObjectTGSize(1, 1, 1);
1941-
if (AS) {
1940+
if (Desc.AS) {
19421941
IRVersionedASInfo ASInfo;
19431942
if (IRShaderReflectionCopyAmplificationInfo(
19441943
ASIR.Reflection.get(), IRReflectionVersion_1_0, &ASInfo)) {
@@ -2199,24 +2198,19 @@ class MTLDevice : public offloadtest::Device {
21992198
return PipelineStateOrErr.takeError();
22002199
IS.Pipeline = std::move(*PipelineStateOrErr);
22012200
} else if (P.isMeshShaderRaster()) {
2202-
std::optional<ShaderContainer> AS;
2203-
ShaderContainer MS = {};
2204-
std::optional<ShaderContainer> PS;
2201+
MeshShaderRasterPipelineCreateDesc PipelineDesc = {};
2202+
PipelineDesc.Topology = P.Bindings.Topology;
2203+
PipelineDesc.DSFormat = Format::D32FloatS8Uint;
2204+
PipelineDesc.RTFormats = RTFormats;
22052205
for (auto &Shader : P.Shaders) {
22062206
ShaderContainer SC = {};
22072207
SC.EntryPoint = Shader.Entry;
22082208
SC.Shader = Shader.Shader.get();
2209-
if (Shader.Stage == Stages::Amplification)
2210-
AS = std::move(SC);
2211-
else if (Shader.Stage == Stages::Mesh)
2212-
MS = std::move(SC);
2213-
else if (Shader.Stage == Stages::Pixel)
2214-
PS = std::move(SC);
2209+
PipelineDesc.setShader(Shader.Stage, std::move(SC));
22152210
}
22162211

2217-
auto PipelineStateOrErr =
2218-
createPipelineAsMsPs("Mesh Shader Pipeline State", Bindings,
2219-
RTFormats, Format::D32FloatS8Uint, AS, MS, PS);
2212+
auto PipelineStateOrErr = createMeshShaderRasterPipeline(
2213+
"Mesh Shader Pipeline State", Bindings, PipelineDesc);
22202214
if (!PipelineStateOrErr)
22212215
return PipelineStateOrErr.takeError();
22222216
IS.Pipeline = std::move(*PipelineStateOrErr);

0 commit comments

Comments
 (0)