Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions include/API/Device.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,36 @@ struct TraditionalRasterPipelineCreateDesc {
}
};

struct MeshShaderRasterPipelineCreateDesc {
llvm::SmallVector<Format> RTFormats;
std::optional<Format> DSFormat;
PrimitiveTopology Topology;

ShaderContainer MS;
std::optional<ShaderContainer> AS;
std::optional<ShaderContainer> PS;

void setShader(Stages Stage, ShaderContainer &&SC) {
switch (Stage) {
case Stages::Amplification:
AS = std::move(SC);
break;
case Stages::Mesh:
MS = std::move(SC);
break;
case Stages::Pixel:
PS = std::move(SC);
break;
case Stages::Vertex:
case Stages::Hull:
case Stages::Domain:
case Stages::Geometry:
case Stages::Compute:
llvm_unreachable("Not a mesh raster pipeline stage.");
}
}
};

class PipelineState {
public:
GPUAPI API;
Expand Down Expand Up @@ -208,6 +238,11 @@ class Device {
llvm::StringRef Name, const BindingsDesc &BindingsDesc,
const TraditionalRasterPipelineCreateDesc &Desc) = 0;

virtual llvm::Expected<std::unique_ptr<PipelineState>>
createMeshShaderRasterPipeline(
llvm::StringRef Name, const BindingsDesc &BindingsDesc,
const MeshShaderRasterPipelineCreateDesc &Desc) = 0;

virtual llvm::Expected<std::unique_ptr<Fence>>
createFence(llvm::StringRef Name) = 0;

Expand Down
79 changes: 31 additions & 48 deletions lib/API/DX/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1226,48 +1226,45 @@ class DXDevice : public offloadtest::Device {
getDXPrimitiveTopology(Desc.Topology, Desc.PatchControlPoints));
}

llvm::Expected<std::unique_ptr<PipelineState>>
createPipelineAsMsPs(llvm::StringRef Name, const BindingsDesc &BndDesc,
llvm::ArrayRef<Format> RTFormats,
std::optional<Format> DSFormat,
std::optional<ShaderContainer> AS, ShaderContainer MS,
std::optional<ShaderContainer> PS) {
assert(RTFormats.size() <= 8);
llvm::Expected<std::unique_ptr<PipelineState>> createMeshShaderRasterPipeline(
llvm::StringRef Name, const BindingsDesc &BindingsDesc,
const MeshShaderRasterPipelineCreateDesc &Desc) override {
assert(Desc.RTFormats.size() <= 8);

ComPtr<ID3D12RootSignature> RootSig;
if (auto Err = createRootSignature(Name, BndDesc, MS,
if (auto Err = createRootSignature(Name, BindingsDesc, Desc.MS,
/*IsGraphics=*/true, RootSig))
return Err;

const D3D12_SHADER_BYTECODE MSBytecode = {MS.Shader->getBuffer().data(),
MS.Shader->getBuffer().size()};
const D3D12_SHADER_BYTECODE MSBytecode = {
Desc.MS.Shader->getBuffer().data(), Desc.MS.Shader->getBuffer().size()};
if (MSBytecode.BytecodeLength == 0)
return llvm::createStringError(
std::errc::invalid_argument,
"Mesh shader pipeline requires a mesh shader.");

// The amplification (task) shader is optional.
D3D12_SHADER_BYTECODE ASBytecode = {};
if (AS) {
assert((*AS).Shader->getBufferSize() > 0 &&
if (Desc.AS) {
assert((*Desc.AS).Shader->getBufferSize() > 0 &&
"The passed task/amplification shader was empty.");
ASBytecode = {(*AS).Shader->getBuffer().data(),
(*AS).Shader->getBuffer().size()};
ASBytecode = {(*Desc.AS).Shader->getBuffer().data(),
(*Desc.AS).Shader->getBuffer().size()};
}

// The pixel shader is optional
D3D12_SHADER_BYTECODE PSBytecode = {};
if (PS) {
assert((*PS).Shader->getBufferSize() > 0 &&
if (Desc.PS) {
assert((*Desc.PS).Shader->getBufferSize() > 0 &&
"The passed pixel shader was empty.");
PSBytecode = {(*PS).Shader->getBuffer().data(),
(*PS).Shader->getBuffer().size()};
PSBytecode = {(*Desc.PS).Shader->getBuffer().data(),
(*Desc.PS).Shader->getBuffer().size()};
}

D3D12_RT_FORMAT_ARRAY RTArray = {};
RTArray.NumRenderTargets = static_cast<UINT>(RTFormats.size());
for (size_t I = 0; I < RTFormats.size(); ++I)
RTArray.RTFormats[I] = getDXGIFormat(RTFormats[I]);
RTArray.NumRenderTargets = static_cast<UINT>(Desc.RTFormats.size());
for (size_t I = 0; I < Desc.RTFormats.size(); ++I)
RTArray.RTFormats[I] = getDXGIFormat(Desc.RTFormats[I]);

CD3DX12_DEPTH_STENCIL_DESC1 DepthStencil(D3D12_DEFAULT);
DepthStencil.DepthEnable = true;
Expand All @@ -1287,10 +1284,10 @@ class DXDevice : public offloadtest::Device {
Stream.BlendState = CD3DX12_BLEND_DESC(D3D12_DEFAULT);
Stream.DepthStencilState = DepthStencil;
Stream.SampleMask = UINT_MAX;
Stream.PrimitiveTopologyType = D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE;
Stream.PrimitiveTopologyType = getDXPrimitiveTopologyType(Desc.Topology);
Stream.RTVFormats = RTArray;
if (DSFormat)
Stream.DSVFormat = getDXGIFormat(*DSFormat);
if (Desc.DSFormat)
Stream.DSVFormat = getDXGIFormat(*Desc.DSFormat);
Stream.SampleDesc = SampleDesc;

const D3D12_PIPELINE_STATE_STREAM_DESC StreamDesc = {sizeof(Stream),
Expand Down Expand Up @@ -2564,38 +2561,24 @@ class DXDevice : public offloadtest::Device {
llvm::outs() << "Traditional Raster Pipeline created.\n";

} else if (P.isMeshShaderRaster()) {

std::optional<ShaderContainer> AS = {};
ShaderContainer MS = {};
std::optional<ShaderContainer> PS = {};
MeshShaderRasterPipelineCreateDesc PipelineDesc = {};
PipelineDesc.Topology = P.Bindings.Topology;
PipelineDesc.DSFormat = Format::D32FloatS8Uint;
for (auto &Shader : P.Shaders) {
if (Shader.Stage == Stages::Amplification) {
ShaderContainer Container;
Container.EntryPoint = Shader.Entry;
Container.Shader = Shader.Shader.get();
AS = Container;
} else if (Shader.Stage == Stages::Mesh) {
MS.EntryPoint = Shader.Entry;
MS.Shader = Shader.Shader.get();
} else if (Shader.Stage == Stages::Pixel) {
ShaderContainer Container;
Container.EntryPoint = Shader.Entry;
Container.Shader = Shader.Shader.get();
PS = Container;
}
ShaderContainer SC = {};
SC.EntryPoint = Shader.Entry;
SC.Shader = Shader.Shader.get();
PipelineDesc.setShader(Shader.Stage, std::move(SC));
}

auto FormatOrErr = toFormat(P.Bindings.RTargetBufferPtr->Format,
P.Bindings.RTargetBufferPtr->Channels);
if (!FormatOrErr)
return FormatOrErr.takeError();
PipelineDesc.RTFormats.push_back(*FormatOrErr);

llvm::SmallVector<Format> RTFormats;
RTFormats.push_back(*FormatOrErr);

auto PipelineStateOrErr =
createPipelineAsMsPs("Mesh Shader Pipeline State", BndDesc,
RTFormats, Format::D32FloatS8Uint, AS, MS, PS);
auto PipelineStateOrErr = createMeshShaderRasterPipeline(
"Mesh Shader Pipeline State", BndDesc, PipelineDesc);

if (!PipelineStateOrErr)
return PipelineStateOrErr.takeError();
Expand Down
74 changes: 34 additions & 40 deletions lib/API/MTL/MTLDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1804,12 +1804,9 @@ class MTLDevice : public offloadtest::Device {
DSState, MTL::CullModeNone);
}

llvm::Expected<std::unique_ptr<PipelineState>>
createPipelineAsMsPs(llvm::StringRef Name, const BindingsDesc &BindingsDesc,
llvm::ArrayRef<Format> RTFormats,
std::optional<Format> DSFormat,
std::optional<ShaderContainer> AS, ShaderContainer MS,
std::optional<ShaderContainer> PS) {
llvm::Expected<std::unique_ptr<PipelineState>> createMeshShaderRasterPipeline(
llvm::StringRef Name, const BindingsDesc &BindingsDesc,
const MeshShaderRasterPipelineCreateDesc &Desc) override {
IRRootSignaturePtr RootSig;
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer;
if (auto Err = createRootSignature(BindingsDesc, /*IsGraphics=*/true,
Expand Down Expand Up @@ -1846,54 +1843,56 @@ class MTLDevice : public offloadtest::Device {
MetalIR MSIR;
MTLPtr<MTL::Library> MSLib;
MTLPtr<MTL::Function> MSFn;
if (auto Err = compileStage(Stages::Mesh, MS, "mesh", MSIR, MSLib, MSFn))
if (auto Err =
compileStage(Stages::Mesh, Desc.MS, "mesh", MSIR, MSLib, MSFn))
return Err;

MetalIR ASIR;
MTLPtr<MTL::Library> ASLib;
MTLPtr<MTL::Function> ASFn;
if (AS) {
if (auto Err = compileStage(Stages::Amplification, *AS, "amplification",
ASIR, ASLib, ASFn))
if (Desc.AS) {
if (auto Err = compileStage(Stages::Amplification, *Desc.AS,
"amplification", ASIR, ASLib, ASFn))
return Err;
}

MetalIR PSIR;
MTLPtr<MTL::Library> PSLib;
MTLPtr<MTL::Function> PSFn;
if (PS) {
if (auto Err =
compileStage(Stages::Pixel, *PS, "fragment", PSIR, PSLib, PSFn))
if (Desc.PS) {
if (auto Err = compileStage(Stages::Pixel, *Desc.PS, "fragment", PSIR,
PSLib, PSFn))
return Err;
}

MTL::MeshRenderPipelineDescriptor *Desc =
MTL::MeshRenderPipelineDescriptor *MSPDesc =
MTL::MeshRenderPipelineDescriptor::alloc()->init();
auto DescScope = llvm::scope_exit([&] { Desc->release(); });
auto DescScope = llvm::scope_exit([&] { MSPDesc->release(); });

Desc->setMeshFunction(MSFn.get());
MSPDesc->setMeshFunction(MSFn.get());
if (ASFn)
Desc->setObjectFunction(ASFn.get());
MSPDesc->setObjectFunction(ASFn.get());
if (PSFn)
Desc->setFragmentFunction(PSFn.get());
MSPDesc->setFragmentFunction(PSFn.get());

for (size_t I = 0; I < RTFormats.size(); ++I) {
for (size_t I = 0; I < Desc.RTFormats.size(); ++I) {
MTL::RenderPipelineColorAttachmentDescriptor *RPCA =
MTL::RenderPipelineColorAttachmentDescriptor::alloc()->init();
RPCA->setPixelFormat(getMetalPixelFormat(RTFormats[I]));
Desc->colorAttachments()->setObject(RPCA, I);
RPCA->setPixelFormat(getMetalPixelFormat(Desc.RTFormats[I]));
MSPDesc->colorAttachments()->setObject(RPCA, I);
RPCA->release();
}

if (DSFormat) {
const MTL::PixelFormat DSPixelFormat = getMetalPixelFormat(*DSFormat);
Desc->setDepthAttachmentPixelFormat(DSPixelFormat);
if (isStencilFormat(*DSFormat))
Desc->setStencilAttachmentPixelFormat(DSPixelFormat);
if (Desc.DSFormat) {
const MTL::PixelFormat DSPixelFormat =
getMetalPixelFormat(*Desc.DSFormat);
MSPDesc->setDepthAttachmentPixelFormat(DSPixelFormat);
if (isStencilFormat(*Desc.DSFormat))
MSPDesc->setStencilAttachmentPixelFormat(DSPixelFormat);
}

MTL::RenderPipelineState *PSO = Device->newRenderPipelineState(
Desc, MTL::PipelineOptionNone, /*reflection=*/nullptr, &Error);
MSPDesc, MTL::PipelineOptionNone, /*reflection=*/nullptr, &Error);
if (Error)
return toError(Error);

Expand All @@ -1918,7 +1917,7 @@ class MTLDevice : public offloadtest::Device {
}

MTL::Size ObjectTGSize(1, 1, 1);
if (AS) {
if (Desc.AS) {
IRVersionedASInfo ASInfo;
if (IRShaderReflectionCopyAmplificationInfo(
ASIR.Reflection.get(), IRReflectionVersion_1_0, &ASInfo)) {
Expand Down Expand Up @@ -2021,24 +2020,19 @@ class MTLDevice : public offloadtest::Device {
return PipelineStateOrErr.takeError();
IS.Pipeline = std::move(*PipelineStateOrErr);
} else if (P.isMeshShaderRaster()) {
std::optional<ShaderContainer> AS;
ShaderContainer MS = {};
std::optional<ShaderContainer> PS;
MeshShaderRasterPipelineCreateDesc PipelineDesc = {};
PipelineDesc.Topology = P.Bindings.Topology;
PipelineDesc.DSFormat = Format::D32FloatS8Uint;
PipelineDesc.RTFormats = RTFormats;
for (auto &Shader : P.Shaders) {
ShaderContainer SC = {};
SC.EntryPoint = Shader.Entry;
SC.Shader = Shader.Shader.get();
if (Shader.Stage == Stages::Amplification)
AS = std::move(SC);
else if (Shader.Stage == Stages::Mesh)
MS = std::move(SC);
else if (Shader.Stage == Stages::Pixel)
PS = std::move(SC);
PipelineDesc.setShader(Shader.Stage, std::move(SC));
}

auto PipelineStateOrErr =
createPipelineAsMsPs("Mesh Shader Pipeline State", Bindings,
RTFormats, Format::D32FloatS8Uint, AS, MS, PS);
auto PipelineStateOrErr = createMeshShaderRasterPipeline(
"Mesh Shader Pipeline State", Bindings, PipelineDesc);
if (!PipelineStateOrErr)
return PipelineStateOrErr.takeError();
IS.Pipeline = std::move(*PipelineStateOrErr);
Expand Down
Loading
Loading