Skip to content

Commit bbc48d1

Browse files
MarijnS95claude
andcommitted
Bind compute PSO inside ComputeEncoder::dispatch()
dispatch() now takes a PipelineState argument and binds it before issuing the dispatch, the same shape as RenderEncoder::drawInstanced(). Making the PSO a required parameter of dispatch() enforces at the type level that callers cannot forget to bind a pipeline (an invalid dispatch on every backend), and keeps the bind adjacent to the command it applies to instead of relying on a separate "last bound" state on the command list. It also hides per-API requirements behind the encoder: on Metal the threadgroup-size lookup from PSO reflection now lives inside dispatch() instead of leaking out as a setThreadGroupSize() call every caller had to remember, letting us drop the encoder-state helpers ThreadsPerGroup / setThreadGroupSize() entirely. With the bind moved, the per-backend PSO bind in createComputeCommands() (DX, VK, MTL) is removed. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 9f2f14d commit bbc48d1

4 files changed

Lines changed: 74 additions & 71 deletions

File tree

include/API/Encoder.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,11 @@ class ComputeEncoder : public CommandEncoder {
7171
using CommandEncoder::CommandEncoder;
7272

7373
/// Dispatch a compute grid. GroupCount specifies how many workgroups to
74-
/// launch in each dimension. The workgroup size is derived from the bound
75-
/// pipeline state (e.g. the shader's numthreads attribute).
76-
virtual llvm::Error dispatch(uint32_t GroupCountX, uint32_t GroupCountY,
77-
uint32_t GroupCountZ) = 0;
74+
/// launch in each dimension. The workgroup size is derived from \p PSO
75+
/// (e.g. the shader's numthreads attribute), which is also bound for the
76+
/// dispatch.
77+
virtual llvm::Error dispatch(const PipelineState &PSO, uint32_t GroupCountX,
78+
uint32_t GroupCountY, uint32_t GroupCountZ) = 0;
7879

7980
/// Copy \p Size bytes from \p Src at \p SrcOffset to \p Dst at
8081
/// \p DstOffset.

lib/API/DX/Device.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -627,12 +627,15 @@ class DXComputeEncoder : public offloadtest::ComputeEncoder {
627627
void popDebugGroup() override {}
628628
void insertDebugSignpost(llvm::StringRef Label) override {}
629629

630-
llvm::Error dispatch(uint32_t GroupCountX, uint32_t GroupCountY,
630+
llvm::Error dispatch(const offloadtest::PipelineState &PSO,
631+
uint32_t GroupCountX, uint32_t GroupCountY,
631632
uint32_t GroupCountZ) override {
633+
const auto &DXPSO = llvm::cast<DXPipelineState>(PSO);
632634
addUAVBarrier();
633635
insertDebugSignpost(llvm::formatv("Dispatch [{0},{1},{2}]", GroupCountX,
634636
GroupCountY, GroupCountZ)
635637
.str());
638+
CB.CmdList->SetPipelineState(DXPSO.PSO.Get());
636639
CB.CmdList->Dispatch(GroupCountX, GroupCountY, GroupCountZ);
637640
return llvm::Error::success();
638641
}
@@ -1908,7 +1911,6 @@ class DXDevice : public offloadtest::Device {
19081911
const DXPipelineState &DXPipeline =
19091912
llvm::cast<DXPipelineState>(*IS.Pipeline.get());
19101913
IS.CB->CmdList->SetComputeRootSignature(DXPipeline.RootSig.Get());
1911-
IS.CB->CmdList->SetPipelineState(DXPipeline.PSO.Get());
19121914

19131915
const uint32_t Inc = Device->GetDescriptorHandleIncrementSize(
19141916
D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV);
@@ -1983,10 +1985,10 @@ class DXDevice : public offloadtest::Device {
19831985
if (!EncoderOrErr)
19841986
return EncoderOrErr.takeError();
19851987
auto &Encoder = *EncoderOrErr.get();
1986-
if (auto Err =
1987-
Encoder.dispatch(P.DispatchParameters.DispatchGroupCount[0],
1988-
P.DispatchParameters.DispatchGroupCount[1],
1989-
P.DispatchParameters.DispatchGroupCount[2]))
1988+
if (auto Err = Encoder.dispatch(
1989+
*IS.Pipeline.get(), P.DispatchParameters.DispatchGroupCount[0],
1990+
P.DispatchParameters.DispatchGroupCount[1],
1991+
P.DispatchParameters.DispatchGroupCount[2]))
19901992
return Err;
19911993
Encoder.endEncoding();
19921994
}

lib/API/MTL/MTLDevice.cpp

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -416,11 +416,6 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
416416
MTL::ComputeCommandEncoder *ComputeEnc = nullptr;
417417
MTL::BlitCommandEncoder *BlitEnc = nullptr;
418418

419-
/// Threadgroup size from shader reflection (the numthreads() attribute
420-
/// persisted in the transpiled Metallib). Must be set via
421-
/// setThreadGroupSize() before dispatching.
422-
MTL::Size ThreadsPerGroup = {1, 1, 1};
423-
424419
/// Accumulated barrier scope from commands recorded since the last barrier.
425420
MTL::BarrierScope PendingScope = MTL::BarrierScope(0);
426421

@@ -477,13 +472,6 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
477472

478473
MTL::ComputeCommandEncoder *getNative() const { return ComputeEnc; }
479474

480-
/// Set the threadgroup size for subsequent dispatch calls. The values must
481-
/// come from shader reflection (the numthreads() attribute in the HLSL
482-
/// source, persisted in the transpiled Metallib).
483-
void setThreadGroupSize(NS::UInteger X, NS::UInteger Y, NS::UInteger Z) {
484-
ThreadsPerGroup = MTL::Size(X, Y, Z);
485-
}
486-
487475
MTL::CommandEncoder *getActiveEncoder() const {
488476
if (ComputeEnc)
489477
return ComputeEnc;
@@ -507,14 +495,37 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
507495
NS::String::string(Label.data(), NS::UTF8StringEncoding));
508496
}
509497

510-
llvm::Error dispatch(uint32_t GroupCountX, uint32_t GroupCountY,
498+
llvm::Error dispatch(const offloadtest::PipelineState &PSO,
499+
uint32_t GroupCountX, uint32_t GroupCountY,
511500
uint32_t GroupCountZ) override {
501+
const auto &MTLPSO = llvm::cast<MTLPipelineState>(PSO);
502+
if (!MTLPSO.ComputePipeline)
503+
return llvm::createStringError(
504+
std::errc::invalid_argument,
505+
"PipelineState bound to dispatch() is not a compute pipeline.");
512506
if (auto Err = ensureComputeEncoder())
513507
return Err;
514508
flushBarrier();
515509
insertDebugSignpost(llvm::formatv("Dispatch [{0},{1},{2}]", GroupCountX,
516510
GroupCountY, GroupCountZ)
517511
.str());
512+
ComputeEnc->setComputePipelineState(MTLPSO.ComputePipeline);
513+
514+
// Threadgroup size comes from numthreads() in the HLSL source, persisted
515+
// in the transpiled Metallib via shader reflection.
516+
MTL::Size ThreadsPerGroup(
517+
MTLPSO.ComputePipeline->maxTotalThreadsPerThreadgroup(), 1, 1);
518+
if (MTLPSO.Reflection) {
519+
IRVersionedCSInfo Info;
520+
if (IRShaderReflectionCopyComputeInfo(MTLPSO.Reflection.get(),
521+
IRReflectionVersion_1_0, &Info)) {
522+
ThreadsPerGroup =
523+
MTL::Size(Info.info_1_0.tg_size[0], Info.info_1_0.tg_size[1],
524+
Info.info_1_0.tg_size[2]);
525+
}
526+
IRShaderReflectionReleaseComputeInfo(&Info);
527+
}
528+
518529
const MTL::Size GridSize(ThreadsPerGroup.width * GroupCountX,
519530
ThreadsPerGroup.height * GroupCountY,
520531
ThreadsPerGroup.depth * GroupCountZ);
@@ -1255,7 +1266,6 @@ class MTLDevice : public offloadtest::Device {
12551266
MTL::ComputeCommandEncoder *NativeEncoder = Encoder.getNative();
12561267

12571268
const auto &PS = llvm::cast<MTLPipelineState>(IS.Pipeline.get());
1258-
NativeEncoder->setComputePipelineState(PS->ComputePipeline);
12591269
MTLGPUDescriptorHandle Handle = {};
12601270
if (IS.DescHeap) {
12611271
IS.DescHeap->bind(NativeEncoder);
@@ -1275,21 +1285,8 @@ class MTLDevice : public offloadtest::Device {
12751285
MTL::ResourceUsageRead |
12761286
MTL::ResourceUsageWrite);
12771287

1278-
NS::UInteger TGS[3] = {PS->ComputePipeline->maxTotalThreadsPerThreadgroup(),
1279-
1, 1};
1280-
if (PS->Reflection) {
1281-
IRVersionedCSInfo Info;
1282-
if (IRShaderReflectionCopyComputeInfo(PS->Reflection.get(),
1283-
IRReflectionVersion_1_0, &Info)) {
1284-
TGS[0] = Info.info_1_0.tg_size[0];
1285-
TGS[1] = Info.info_1_0.tg_size[1];
1286-
TGS[2] = Info.info_1_0.tg_size[2];
1287-
}
1288-
IRShaderReflectionReleaseComputeInfo(&Info);
1289-
}
1290-
Encoder.setThreadGroupSize(TGS[0], TGS[1], TGS[2]);
1291-
1292-
if (auto Err = Encoder.dispatch(P.DispatchParameters.DispatchGroupCount[0],
1288+
if (auto Err = Encoder.dispatch(*IS.Pipeline.get(),
1289+
P.DispatchParameters.DispatchGroupCount[0],
12931290
P.DispatchParameters.DispatchGroupCount[1],
12941291
P.DispatchParameters.DispatchGroupCount[2]))
12951292
return Err;

lib/API/VK/Device.cpp

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,32 @@ class VulkanCommandBuffer : public offloadtest::CommandBuffer {
687687
VulkanCommandBuffer() : CommandBuffer(GPUAPI::Vulkan) {}
688688
};
689689

690+
class VulkanPipelineState : public offloadtest::PipelineState {
691+
public:
692+
std::string Name;
693+
VkDevice Dev;
694+
VkPipeline Pipeline;
695+
VkPipelineLayout Layout;
696+
llvm::SmallVector<VkDescriptorSetLayout> SetLayouts;
697+
698+
VulkanPipelineState(llvm::StringRef Name, VkDevice Dev, VkPipeline Pipeline,
699+
VkPipelineLayout Layout,
700+
llvm::SmallVector<VkDescriptorSetLayout> SetLayouts)
701+
: offloadtest::PipelineState(GPUAPI::Vulkan), Name(Name.str()), Dev(Dev),
702+
Pipeline(Pipeline), Layout(Layout), SetLayouts(std::move(SetLayouts)) {}
703+
704+
~VulkanPipelineState() override {
705+
vkDestroyPipeline(Dev, Pipeline, nullptr);
706+
vkDestroyPipelineLayout(Dev, Layout, nullptr);
707+
for (VkDescriptorSetLayout L : SetLayouts)
708+
vkDestroyDescriptorSetLayout(Dev, L, nullptr);
709+
}
710+
711+
static bool classof(const offloadtest::PipelineState *B) {
712+
return B->getAPI() == GPUAPI::Vulkan;
713+
}
714+
};
715+
690716
class VKComputeEncoder : public offloadtest::ComputeEncoder {
691717
VulkanCommandBuffer &CB;
692718

@@ -713,13 +739,17 @@ class VKComputeEncoder : public offloadtest::ComputeEncoder {
713739
CB.insertDebugSignpost(Label);
714740
}
715741

716-
llvm::Error dispatch(uint32_t GroupCountX, uint32_t GroupCountY,
742+
llvm::Error dispatch(const offloadtest::PipelineState &PSO,
743+
uint32_t GroupCountX, uint32_t GroupCountY,
717744
uint32_t GroupCountZ) override {
745+
const auto &VKPSO = llvm::cast<VulkanPipelineState>(PSO);
718746
addDstBarrier(VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
719747
VK_ACCESS_SHADER_WRITE_BIT | VK_ACCESS_SHADER_READ_BIT);
720748
insertDebugSignpost(llvm::formatv("Dispatch [{0},{1},{2}]", GroupCountX,
721749
GroupCountY, GroupCountZ)
722750
.str());
751+
vkCmdBindPipeline(CB.CmdBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
752+
VKPSO.Pipeline);
723753
vkCmdDispatch(CB.CmdBuffer, GroupCountX, GroupCountY, GroupCountZ);
724754
return llvm::Error::success();
725755
}
@@ -749,32 +779,6 @@ VulkanCommandBuffer::createComputeEncoder() {
749779
return Enc;
750780
}
751781

752-
class VulkanPipelineState : public offloadtest::PipelineState {
753-
public:
754-
std::string Name;
755-
VkDevice Dev;
756-
VkPipeline Pipeline;
757-
VkPipelineLayout Layout;
758-
llvm::SmallVector<VkDescriptorSetLayout> SetLayouts;
759-
760-
VulkanPipelineState(llvm::StringRef Name, VkDevice Dev, VkPipeline Pipeline,
761-
VkPipelineLayout Layout,
762-
llvm::SmallVector<VkDescriptorSetLayout> SetLayouts)
763-
: offloadtest::PipelineState(GPUAPI::Vulkan), Name(Name.str()), Dev(Dev),
764-
Pipeline(Pipeline), Layout(Layout), SetLayouts(std::move(SetLayouts)) {}
765-
766-
~VulkanPipelineState() override {
767-
vkDestroyPipeline(Dev, Pipeline, nullptr);
768-
vkDestroyPipelineLayout(Dev, Layout, nullptr);
769-
for (VkDescriptorSetLayout L : SetLayouts)
770-
vkDestroyDescriptorSetLayout(Dev, L, nullptr);
771-
}
772-
773-
static bool classof(const offloadtest::PipelineState *B) {
774-
return B->getAPI() == GPUAPI::Vulkan;
775-
}
776-
};
777-
778782
static VkAttachmentLoadOp getVkLoadOp(offloadtest::LoadAction Action) {
779783
switch (Action) {
780784
case offloadtest::LoadAction::Load:
@@ -2898,7 +2902,6 @@ class VulkanDevice : public offloadtest::Device {
28982902
: VK_PIPELINE_BIND_POINT_COMPUTE;
28992903
const VulkanPipelineState &VulkanPipeline =
29002904
llvm::cast<VulkanPipelineState>(*IS.Pipeline.get());
2901-
vkCmdBindPipeline(IS.CB->CmdBuffer, BindPoint, VulkanPipeline.Pipeline);
29022905
if (IS.DescriptorSets.size() > 0)
29032906
vkCmdBindDescriptorSets(
29042907
IS.CB->CmdBuffer, BindPoint, VulkanPipeline.Layout, 0,
@@ -2917,10 +2920,10 @@ class VulkanDevice : public offloadtest::Device {
29172920
if (!EncoderOrErr)
29182921
return EncoderOrErr.takeError();
29192922
auto &Encoder = *EncoderOrErr.get();
2920-
if (auto Err =
2921-
Encoder.dispatch(P.DispatchParameters.DispatchGroupCount[0],
2922-
P.DispatchParameters.DispatchGroupCount[1],
2923-
P.DispatchParameters.DispatchGroupCount[2]))
2923+
if (auto Err = Encoder.dispatch(
2924+
*IS.Pipeline.get(), P.DispatchParameters.DispatchGroupCount[0],
2925+
P.DispatchParameters.DispatchGroupCount[1],
2926+
P.DispatchParameters.DispatchGroupCount[2]))
29242927
return Err;
29252928
Encoder.endEncoding();
29262929
llvm::outs() << "Dispatched compute shader: { "

0 commit comments

Comments
 (0)