Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions include/API/Encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,11 @@ class ComputeEncoder : public CommandEncoder {
using CommandEncoder::CommandEncoder;

/// Dispatch a compute grid. GroupCount specifies how many workgroups to
/// launch in each dimension. The workgroup size is derived from the bound
/// pipeline state (e.g. the shader's numthreads attribute).
virtual llvm::Error dispatch(uint32_t GroupCountX, uint32_t GroupCountY,
uint32_t GroupCountZ) = 0;
/// launch in each dimension. The workgroup size is derived from \p PSO
/// (e.g. the shader's numthreads attribute), which is also bound for the
/// dispatch.
virtual llvm::Error dispatch(const PipelineState &PSO, uint32_t GroupCountX,
uint32_t GroupCountY, uint32_t GroupCountZ) = 0;

/// Copy \p Size bytes from \p Src at \p SrcOffset to \p Dst at
/// \p DstOffset.
Expand Down
14 changes: 8 additions & 6 deletions lib/API/DX/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -627,12 +627,15 @@ class DXComputeEncoder : public offloadtest::ComputeEncoder {
void popDebugGroup() override {}
void insertDebugSignpost(llvm::StringRef Label) override {}

llvm::Error dispatch(uint32_t GroupCountX, uint32_t GroupCountY,
/// Record a compute dispatch that uses \p PSO.
///
/// \p PSO is downcast with llvm::cast to DXPipelineState, so it must have
/// been created by the DirectX backend. The recording order is:
/// addUAVBarrier(), a "Dispatch [x,y,z]" debug signpost, SetPipelineState
/// with the pipeline's PSO, and finally Dispatch with the group counts.
///
/// \param PSO Pipeline state to bind for this dispatch.
/// \param GroupCountX Number of workgroups to launch in X.
/// \param GroupCountY Number of workgroups to launch in Y.
/// \param GroupCountZ Number of workgroups to launch in Z.
/// \returns llvm::Error::success() unconditionally.
llvm::Error dispatch(const offloadtest::PipelineState &PSO,
                     uint32_t GroupCountX, uint32_t GroupCountY,
                     uint32_t GroupCountZ) override {
  const auto &Pipeline = llvm::cast<DXPipelineState>(PSO);

  addUAVBarrier();

  // Build the label once so the signpost call itself stays short.
  const auto Label = llvm::formatv("Dispatch [{0},{1},{2}]", GroupCountX,
                                   GroupCountY, GroupCountZ)
                         .str();
  insertDebugSignpost(Label);

  // Bind the pipeline immediately before the dispatch that consumes it.
  CB.CmdList->SetPipelineState(Pipeline.PSO.Get());
  CB.CmdList->Dispatch(GroupCountX, GroupCountY, GroupCountZ);

  return llvm::Error::success();
}
Expand Down Expand Up @@ -1908,7 +1911,6 @@ class DXDevice : public offloadtest::Device {
const DXPipelineState &DXPipeline =
llvm::cast<DXPipelineState>(*IS.Pipeline.get());
IS.CB->CmdList->SetComputeRootSignature(DXPipeline.RootSig.Get());
IS.CB->CmdList->SetPipelineState(DXPipeline.PSO.Get());

const uint32_t Inc = Device->GetDescriptorHandleIncrementSize(
D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
Expand Down Expand Up @@ -1983,10 +1985,10 @@ class DXDevice : public offloadtest::Device {
if (!EncoderOrErr)
return EncoderOrErr.takeError();
auto &Encoder = *EncoderOrErr.get();
if (auto Err =
Encoder.dispatch(P.DispatchParameters.DispatchGroupCount[0],
P.DispatchParameters.DispatchGroupCount[1],
P.DispatchParameters.DispatchGroupCount[2]))
if (auto Err = Encoder.dispatch(
*IS.Pipeline.get(), P.DispatchParameters.DispatchGroupCount[0],
P.DispatchParameters.DispatchGroupCount[1],
P.DispatchParameters.DispatchGroupCount[2]))
return Err;
Encoder.endEncoding();
}
Expand Down
75 changes: 36 additions & 39 deletions lib/API/MTL/MTLDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,14 @@ class MTLPipelineState : public offloadtest::PipelineState {
std::string Name;
IRRootSignaturePtr RootSig;
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer;
IRShaderReflectionPtr Reflection;
MTL::ComputePipelineState *ComputePipeline = nullptr;
MTL::RenderPipelineState *RenderPipeline = nullptr;

// Compute pipeline only state. Threadgroup size comes from numthreads() in
// the HLSL source and is captured from shader reflection at pipeline
// creation, so dispatch() doesn't need to re-query reflection each time.
MTL::Size ThreadsPerGroup = MTL::Size(1, 1, 1);

// Rasterization pipeline only state.
// These are part of the pipeline in DX and VK, but dynamic state in Metal.
// To have a shared API we store these here and set the state when the
Expand All @@ -246,11 +250,11 @@ class MTLPipelineState : public offloadtest::PipelineState {

MTLPipelineState(llvm::StringRef Name, IRRootSignaturePtr RootSig,
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
IRShaderReflectionPtr Reflection,
MTL::ComputePipelineState *ComputePipeline)
MTL::ComputePipelineState *ComputePipeline,
MTL::Size ThreadsPerGroup)
: offloadtest::PipelineState(GPUAPI::Metal), Name(Name),
RootSig(std::move(RootSig)), ArgBuffer(std::move(ArgBuffer)),
Reflection(std::move(Reflection)), ComputePipeline(ComputePipeline) {}
ComputePipeline(ComputePipeline), ThreadsPerGroup(ThreadsPerGroup) {}

MTLPipelineState(llvm::StringRef Name, IRRootSignaturePtr RootSig,
std::unique_ptr<MTLTopLevelArgumentBuffer> ArgBuffer,
Expand Down Expand Up @@ -416,11 +420,6 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
MTL::ComputeCommandEncoder *ComputeEnc = nullptr;
MTL::BlitCommandEncoder *BlitEnc = nullptr;

/// Threadgroup size from shader reflection (the numthreads() attribute
/// persisted in the transpiled Metallib). Must be set via
/// setThreadGroupSize() before dispatching.
MTL::Size ThreadsPerGroup = {1, 1, 1};

/// Accumulated barrier scope from commands recorded since the last barrier.
MTL::BarrierScope PendingScope = MTL::BarrierScope(0);

Expand Down Expand Up @@ -477,13 +476,6 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {

MTL::ComputeCommandEncoder *getNative() const { return ComputeEnc; }

/// Set the threadgroup size for subsequent dispatch calls. The values must
/// come from shader reflection (the numthreads() attribute in the HLSL
/// source, persisted in the transpiled Metallib).
void setThreadGroupSize(NS::UInteger X, NS::UInteger Y, NS::UInteger Z) {
ThreadsPerGroup = MTL::Size(X, Y, Z);
}

MTL::CommandEncoder *getActiveEncoder() const {
if (ComputeEnc)
return ComputeEnc;
Expand All @@ -507,18 +499,26 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
NS::String::string(Label.data(), NS::UTF8StringEncoding));
}

llvm::Error dispatch(uint32_t GroupCountX, uint32_t GroupCountY,
llvm::Error dispatch(const offloadtest::PipelineState &PSO,
uint32_t GroupCountX, uint32_t GroupCountY,
uint32_t GroupCountZ) override {
const auto &MTLPSO = llvm::cast<MTLPipelineState>(PSO);
if (!MTLPSO.ComputePipeline)
return llvm::createStringError(
std::errc::invalid_argument,
"PipelineState bound to dispatch() is not a compute pipeline.");
if (auto Err = ensureComputeEncoder())
return Err;
flushBarrier();
insertDebugSignpost(llvm::formatv("Dispatch [{0},{1},{2}]", GroupCountX,
GroupCountY, GroupCountZ)
.str());
const MTL::Size GridSize(ThreadsPerGroup.width * GroupCountX,
ThreadsPerGroup.height * GroupCountY,
ThreadsPerGroup.depth * GroupCountZ);
ComputeEnc->dispatchThreads(GridSize, ThreadsPerGroup);
ComputeEnc->setComputePipelineState(MTLPSO.ComputePipeline);

const MTL::Size GridSize(MTLPSO.ThreadsPerGroup.width * GroupCountX,
MTLPSO.ThreadsPerGroup.height * GroupCountY,
MTLPSO.ThreadsPerGroup.depth * GroupCountZ);
ComputeEnc->dispatchThreads(GridSize, MTLPSO.ThreadsPerGroup);
addBarrierScope(MTL::BarrierScopeBuffers | MTL::BarrierScopeTextures);
return llvm::Error::success();
}
Expand Down Expand Up @@ -1255,7 +1255,6 @@ class MTLDevice : public offloadtest::Device {
MTL::ComputeCommandEncoder *NativeEncoder = Encoder.getNative();

const auto &PS = llvm::cast<MTLPipelineState>(IS.Pipeline.get());
NativeEncoder->setComputePipelineState(PS->ComputePipeline);
MTLGPUDescriptorHandle Handle = {};
if (IS.DescHeap) {
IS.DescHeap->bind(NativeEncoder);
Expand All @@ -1275,21 +1274,8 @@ class MTLDevice : public offloadtest::Device {
MTL::ResourceUsageRead |
MTL::ResourceUsageWrite);

NS::UInteger TGS[3] = {PS->ComputePipeline->maxTotalThreadsPerThreadgroup(),
1, 1};
if (PS->Reflection) {
IRVersionedCSInfo Info;
if (IRShaderReflectionCopyComputeInfo(PS->Reflection.get(),
IRReflectionVersion_1_0, &Info)) {
TGS[0] = Info.info_1_0.tg_size[0];
TGS[1] = Info.info_1_0.tg_size[1];
TGS[2] = Info.info_1_0.tg_size[2];
}
IRShaderReflectionReleaseComputeInfo(&Info);
}
Encoder.setThreadGroupSize(TGS[0], TGS[1], TGS[2]);

if (auto Err = Encoder.dispatch(P.DispatchParameters.DispatchGroupCount[0],
if (auto Err = Encoder.dispatch(*IS.Pipeline.get(),
P.DispatchParameters.DispatchGroupCount[0],
P.DispatchParameters.DispatchGroupCount[1],
P.DispatchParameters.DispatchGroupCount[2]))
return Err;
Expand Down Expand Up @@ -1542,9 +1528,20 @@ class MTLDevice : public offloadtest::Device {
if (Error)
return toError(Error);

IRVersionedCSInfo Info;
if (!IRShaderReflectionCopyComputeInfo(MetalIR->Reflection.get(),
IRReflectionVersion_1_0, &Info))
return llvm::createStringError(
"Failed to read compute reflection for entry point '%s'; cannot "
"determine threadgroup size from numthreads().",
CS.EntryPoint.c_str());
const MTL::Size ThreadsPerGroup(Info.info_1_0.tg_size[0],
Info.info_1_0.tg_size[1],
Info.info_1_0.tg_size[2]);
IRShaderReflectionReleaseComputeInfo(&Info);

return std::make_unique<MTLPipelineState>(
Name, std::move(RootSig), std::move(ArgBuffer),
std::move(MetalIR->Reflection), PSO);
Name, std::move(RootSig), std::move(ArgBuffer), PSO, ThreadsPerGroup);
}

llvm::Expected<std::unique_ptr<PipelineState>>
Expand Down
67 changes: 35 additions & 32 deletions lib/API/VK/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,32 @@ class VulkanCommandBuffer : public offloadtest::CommandBuffer {
VulkanCommandBuffer() : CommandBuffer(GPUAPI::Vulkan) {}
};

class VulkanPipelineState : public offloadtest::PipelineState {
public:
std::string Name;
VkDevice Dev;
VkPipeline Pipeline;
VkPipelineLayout Layout;
llvm::SmallVector<VkDescriptorSetLayout> SetLayouts;

VulkanPipelineState(llvm::StringRef Name, VkDevice Dev, VkPipeline Pipeline,
VkPipelineLayout Layout,
llvm::SmallVector<VkDescriptorSetLayout> SetLayouts)
: offloadtest::PipelineState(GPUAPI::Vulkan), Name(Name.str()), Dev(Dev),
Pipeline(Pipeline), Layout(Layout), SetLayouts(std::move(SetLayouts)) {}

~VulkanPipelineState() override {
vkDestroyPipeline(Dev, Pipeline, nullptr);
vkDestroyPipelineLayout(Dev, Layout, nullptr);
for (VkDescriptorSetLayout L : SetLayouts)
vkDestroyDescriptorSetLayout(Dev, L, nullptr);
}

static bool classof(const offloadtest::PipelineState *B) {
return B->getAPI() == GPUAPI::Vulkan;
}
};

class VKComputeEncoder : public offloadtest::ComputeEncoder {
VulkanCommandBuffer &CB;

Expand All @@ -713,13 +739,17 @@ class VKComputeEncoder : public offloadtest::ComputeEncoder {
CB.insertDebugSignpost(Label);
}

llvm::Error dispatch(uint32_t GroupCountX, uint32_t GroupCountY,
/// Record a compute dispatch that uses \p PSO.
///
/// \p PSO is downcast with llvm::cast to VulkanPipelineState, so it must
/// have been created by the Vulkan backend. The recording order is:
/// addDstBarrier() for the compute-shader stage with shader read/write
/// access, a "Dispatch [x,y,z]" debug signpost, vkCmdBindPipeline at the
/// compute bind point, and finally vkCmdDispatch with the group counts.
///
/// \param PSO Pipeline state to bind for this dispatch.
/// \param GroupCountX Number of workgroups to launch in X.
/// \param GroupCountY Number of workgroups to launch in Y.
/// \param GroupCountZ Number of workgroups to launch in Z.
/// \returns llvm::Error::success() unconditionally.
llvm::Error dispatch(const offloadtest::PipelineState &PSO,
                     uint32_t GroupCountX, uint32_t GroupCountY,
                     uint32_t GroupCountZ) override {
  const auto &Pipeline = llvm::cast<VulkanPipelineState>(PSO);

  addDstBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
                VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT);

  // Build the label once so the signpost call itself stays short.
  const auto Label = llvm::formatv("Dispatch [{0},{1},{2}]", GroupCountX,
                                   GroupCountY, GroupCountZ)
                         .str();
  insertDebugSignpost(Label);

  // Bind the pipeline immediately before the dispatch that consumes it.
  vkCmdBindPipeline(CB.CmdBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
                    Pipeline.Pipeline);
  vkCmdDispatch(CB.CmdBuffer, GroupCountX, GroupCountY, GroupCountZ);

  return llvm::Error::success();
}
Expand Down Expand Up @@ -749,32 +779,6 @@ VulkanCommandBuffer::createComputeEncoder() {
return Enc;
}

class VulkanPipelineState : public offloadtest::PipelineState {
public:
std::string Name;
VkDevice Dev;
VkPipeline Pipeline;
VkPipelineLayout Layout;
llvm::SmallVector<VkDescriptorSetLayout> SetLayouts;

VulkanPipelineState(llvm::StringRef Name, VkDevice Dev, VkPipeline Pipeline,
VkPipelineLayout Layout,
llvm::SmallVector<VkDescriptorSetLayout> SetLayouts)
: offloadtest::PipelineState(GPUAPI::Vulkan), Name(Name.str()), Dev(Dev),
Pipeline(Pipeline), Layout(Layout), SetLayouts(std::move(SetLayouts)) {}

~VulkanPipelineState() override {
vkDestroyPipeline(Dev, Pipeline, nullptr);
vkDestroyPipelineLayout(Dev, Layout, nullptr);
for (VkDescriptorSetLayout L : SetLayouts)
vkDestroyDescriptorSetLayout(Dev, L, nullptr);
}

static bool classof(const offloadtest::PipelineState *B) {
return B->getAPI() == GPUAPI::Vulkan;
}
};

static VkAttachmentLoadOp getVkLoadOp(offloadtest::LoadAction Action) {
switch (Action) {
case offloadtest::LoadAction::Load:
Expand Down Expand Up @@ -2898,7 +2902,6 @@ class VulkanDevice : public offloadtest::Device {
: VK_PIPELINE_BIND_POINT_COMPUTE;
const VulkanPipelineState &VulkanPipeline =
llvm::cast<VulkanPipelineState>(*IS.Pipeline.get());
vkCmdBindPipeline(IS.CB->CmdBuffer, BindPoint, VulkanPipeline.Pipeline);
if (IS.DescriptorSets.size() > 0)
vkCmdBindDescriptorSets(
IS.CB->CmdBuffer, BindPoint, VulkanPipeline.Layout, 0,
Expand All @@ -2917,10 +2920,10 @@ class VulkanDevice : public offloadtest::Device {
if (!EncoderOrErr)
return EncoderOrErr.takeError();
auto &Encoder = *EncoderOrErr.get();
if (auto Err =
Encoder.dispatch(P.DispatchParameters.DispatchGroupCount[0],
P.DispatchParameters.DispatchGroupCount[1],
P.DispatchParameters.DispatchGroupCount[2]))
if (auto Err = Encoder.dispatch(
*IS.Pipeline.get(), P.DispatchParameters.DispatchGroupCount[0],
P.DispatchParameters.DispatchGroupCount[1],
P.DispatchParameters.DispatchGroupCount[2]))
return Err;
Encoder.endEncoding();
llvm::outs() << "Dispatched compute shader: { "
Expand Down