Commit e0a2df3

MarijnS95 and claude committed
[VK] Add ray tracing pipeline, SBT, and DispatchRays bring-up
First per-backend bring-up in the PSO raytracing series (#1268). Adds the API surface (ComputeEncoder::dispatchRays, Device::createPipelineRT, Device::createShaderBindingTable, RayTracingPipelineCreateDesc) plus the Vulkan implementation behind it. D3D12 and Metal stub the new methods with not-yet-supported errors; their bring-up lands in follow-up PRs.

The pre-existing YAML schema struct from PR #1270 is renamed ShaderBindingTable -> ShaderBindingTableDesc so the bare name is free for the runtime resource class (parallel to BLASDesc / TLASDesc vs AccelerationStructure). A new include/API/ShaderBindingTable.h holds the abstract runtime base; concrete backend SBT classes derive from it with LLVM-style classof / cast<>.

The VulkanDevice's prior `RaytracingFunctions RT` lumped AS and RT pipeline entry points together. They split into two structs, `ASFunctions AS` and `RTPipelineFunctions RT`, matching the actual feature-gate split (AS+ray-query is a complete configuration on its own, RT pipeline is layered on top). `HasRayTracingSupport` renames to `HasASSupport`, and a separate `HasRTPipelineSupport` tracks the new VK_KHR_ray_tracing_pipeline extension.

Vulkan bring-up:

- Extension: VK_KHR_ray_tracing_pipeline is requested when reported, with VkPhysicalDeviceRayTracingPipelineFeaturesKHR chained into the pre-create feature query. After the query the gating rayTracingPipeline bool is checked; capture-replay / trace-rays-indirect / traversal-primitive-culling sub-features are cleared since the tests don't exercise them.
- Function pointers: vkCreateRayTracingPipelinesKHR, vkGetRayTracingShaderGroupHandlesKHR, vkCmdTraceRaysKHR.
- Properties: VkPhysicalDeviceRayTracingPipelinePropertiesKHR is cached at device-create time for SBT handle size / alignment / base-alignment.
- VKRayTracingPipelineState derives from VulkanPipelineState; an IsRayTracing flag on the base lets the existing Vulkan cast<> path stay polymorphic without adding a new GPUAPI value. classof tests both the API and the flag. The derived class also carries a StringMap<uint32_t> resolving each shader EntryPoint or HitGroup Name to its index in the pipeline's group array, plus per-bucket counts so the SBT builder can slice the contiguous handle blob into raygen / miss / hit / callable regions.
- createPipelineRT builds a single VkShaderModule (the DXIL library compiles to one SPIR-V module with multiple OpEntryPoints), then one VkPipelineShaderStageCreateInfo per Shader entry and one VkRayTracingShaderGroupCreateInfoKHR per general shader / hit group. Pipeline layout is shared with the compute path via createPipelineLayout, gated on all six RT stage flags so any binding can be consumed from any RT shader.
- createShaderBindingTable allocates a host-visible coherent buffer big enough for four regions and lays out each entry as [handle bytes][localRootData bytes][padding-to-stride]. Per-region stride = align(handleSize + max-local-root-data-in-region, handleAlignment); per-region size = align(count * stride, baseAlignment). LocalRootData support comes free from the PR1 SBT schema; the test doesn't exercise it yet. Each region's VkStridedDeviceAddressRegionKHR derives from the buffer's vkGetBufferDeviceAddress.
- dispatchRays binds the pipeline at VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, emits a pre-barrier with AS_READ + SHADER_READ/WRITE dst access into RAY_TRACING_SHADER_BIT_KHR, then calls vkCmdTraceRaysKHR with the SBT's four region structs.
- createCommands picks the new bind point for RT pipelines so vkCmdBindDescriptorSets binds to the right point.

executeProgram's isRayTracing branch builds a RayTracingPipelineCreateDesc from the YAML, calls createPipelineRT then createShaderBindingTable, and keeps both on InvocationState for the dispatch.

raygen-roundtrip.test now expects DirectX/Metal/Clang to XFAIL; on a DXC + Vulkan combo with VK_KHR_ray_tracing_pipeline supported the test should PASS via this implementation. On the user's Linux + clang-dxc loop the test still XFAILs because clang-dxc doesn't yet lower [shader("raygeneration")] entry points to SPIR-V, so the Clang XFAIL token catches the compile failure. CI on a working DXC install will exercise the runtime path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
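The per-region stride and size formulas quoted in the createShaderBindingTable bullet can be sketched as a small standalone helper. This is only an illustration of the arithmetic, not the code in this commit: `alignTo`, `RegionLayout`, and `layoutRegion` are hypothetical names, and the alignment values would come from the cached VkPhysicalDeviceRayTracingPipelinePropertiesKHR in the real backend.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Round Value up to the next multiple of Alignment.
inline uint64_t alignTo(uint64_t Value, uint64_t Alignment) {
  assert(Alignment > 0 && "alignment must be non-zero");
  return (Value + Alignment - 1) / Alignment * Alignment;
}

// Hypothetical description of one SBT region (raygen, miss, hit or callable).
struct RegionLayout {
  uint64_t Offset = 0; // byte offset of the region inside the SBT buffer
  uint64_t Stride = 0; // bytes per record: handle + local root data + padding
  uint64_t Size = 0;   // total bytes, padded to the base alignment
};

// LocalRootSizes holds one entry per record in the region: the size in bytes
// of that record's local root data (0 when the record has none).
inline RegionLayout layoutRegion(uint64_t Offset,
                                 const std::vector<uint64_t> &LocalRootSizes,
                                 uint64_t HandleSize, uint64_t HandleAlignment,
                                 uint64_t BaseAlignment) {
  RegionLayout R;
  R.Offset = Offset;
  if (LocalRootSizes.empty())
    return R; // empty region: stride and size stay 0

  const uint64_t MaxLocal =
      *std::max_element(LocalRootSizes.begin(), LocalRootSizes.end());
  // stride = align(handleSize + max-local-root-data-in-region, handleAlignment)
  R.Stride = alignTo(HandleSize + MaxLocal, HandleAlignment);
  // size = align(count * stride, baseAlignment)
  R.Size = alignTo(LocalRootSizes.size() * R.Stride, BaseAlignment);
  return R;
}
```

With illustrative values of handle size 32, handle alignment 32, and base alignment 64, a raygen region holding one record without local root data gets stride 32 and size 64, while a hit region with three records whose largest local root data is 8 bytes gets stride 64 and size 192. Chaining regions is then a matter of passing `Offset + Size` of one region as the `Offset` of the next.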
1 parent ca07851 commit e0a2df3

10 files changed

Lines changed: 705 additions & 46 deletions


include/API/Device.h

Lines changed: 24 additions & 0 deletions
@@ -20,6 +20,7 @@
 #include "API/Capabilities.h"
 #include "API/CommandBuffer.h"
 #include "API/RenderPass.h"
+#include "API/ShaderBindingTable.h"
 #include "API/Texture.h"

 #include "Support/Pipeline.h"
@@ -83,6 +84,21 @@ struct ShaderContainer {
   llvm::SmallVector<SpecializationConstant> SpecializationConstants;
 };

+struct RayTracingShader {
+  Stages Stage;
+  std::string EntryPoint;
+};
+
+struct RayTracingPipelineCreateDesc {
+  // All RT shaders are compiled into a single DXIL library; every entry in
+  // `Shaders` references this same blob via the backend's library-loading
+  // path.
+  const llvm::MemoryBuffer *Library = nullptr;
+  llvm::SmallVector<RayTracingShader> Shaders;
+  llvm::SmallVector<HitGroup> HitGroups;
+  RayTracingPipelineConfig Config;
+};
+
 struct TraditionalRasterPipelineCreateDesc {
   llvm::SmallVector<InputLayoutDesc> InputLayout;
   llvm::SmallVector<Format> RTFormats;
@@ -215,6 +231,14 @@ class Device {
       llvm::StringRef Name, const BindingsDesc &BindingsDesc,
       const TraditionalRasterPipelineCreateDesc &Desc) = 0;

+  virtual llvm::Expected<std::unique_ptr<PipelineState>>
+  createPipelineRT(llvm::StringRef Name, const BindingsDesc &BindingsDesc,
+                   const RayTracingPipelineCreateDesc &Desc) = 0;
+
+  virtual llvm::Expected<std::unique_ptr<ShaderBindingTable>>
+  createShaderBindingTable(const PipelineState &PSO,
+                           const ShaderBindingTableDesc &Desc) = 0;
+
   virtual llvm::Expected<std::unique_ptr<Fence>>
   createFence(llvm::StringRef Name) = 0;

include/API/Encoder.h

Lines changed: 12 additions & 0 deletions
@@ -24,6 +24,7 @@ namespace offloadtest {
 class Buffer;
 class PipelineState;
 class AccelerationStructure;
+class ShaderBindingTable;
 struct BLASBuildRequest;
 struct TLASBuildRequest;

@@ -105,6 +106,17 @@ class ComputeEncoder : public CommandEncoder {
   /// Metal). A barrier covering AS-build writes is implicitly emitted before
   /// any subsequent command that reads from the freshly-built structures.
   virtual llvm::Error batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) = 0;
+
+  /// Trace rays from a RayTracing pipeline. \p PSO must have been created via
+  /// Device::createPipelineRT and \p SBT via Device::createShaderBindingTable
+  /// on that same PSO. \p Width, \p Height, \p Depth are the dispatch
+  /// dimensions passed through to the backend's DispatchRays equivalent
+  /// (D3D12 DispatchRays, Vulkan vkCmdTraceRaysKHR, Metal compute dispatch
+  /// after metal_irconverter lowering).
+  virtual llvm::Error dispatchRays(const PipelineState &PSO,
+                                   const ShaderBindingTable &SBT,
+                                   uint32_t Width, uint32_t Height,
+                                   uint32_t Depth) = 0;
 };

 struct Viewport {

include/API/ShaderBindingTable.h

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+//===- ShaderBindingTable.h - Offload RT Shader Binding Table -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef OFFLOADTEST_API_SHADERBINDINGTABLE_H
+#define OFFLOADTEST_API_SHADERBINDINGTABLE_H
+
+#include "API/API.h"
+
+namespace offloadtest {
+
+/// Runtime shader binding table built from a RayTracing PipelineState plus a
+/// ShaderBindingTableDesc. Concrete subclasses (one per backend) hold the
+/// device-side records and any address ranges needed by the backend's
+/// DispatchRays call.
+class ShaderBindingTable {
+  GPUAPI API;
+
+public:
+  virtual ~ShaderBindingTable();
+  ShaderBindingTable(const ShaderBindingTable &) = delete;
+  ShaderBindingTable &operator=(const ShaderBindingTable &) = delete;
+
+  GPUAPI getAPI() const { return API; }
+
+protected:
+  explicit ShaderBindingTable(GPUAPI API) : API(API) {}
+};
+
+} // namespace offloadtest
+
+#endif // OFFLOADTEST_API_SHADERBINDINGTABLE_H
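The base class above stores a GPUAPI tag so that backend subclasses can take part in LLVM-style classof / cast<>, as the commit message describes. The following standalone sketch shows the general shape of that pattern without depending on LLVM. Everything in the `sketch` namespace is an assumption made for illustration: the enumerator names, the `Example*` subclasses, and the `dynCast` helper (a simplified stand-in for `llvm::dyn_cast`) are not taken from the commit.

```cpp
#include <cassert>

namespace sketch {

// Stand-in for the project's GPUAPI enum; the enumerator names are assumed.
enum class GPUAPI { DirectX, Vulkan, Metal };

// Same shape as the abstract base in the diff, with an inline destructor so
// the example is self-contained.
class ShaderBindingTable {
  GPUAPI API;

public:
  virtual ~ShaderBindingTable() = default;
  ShaderBindingTable(const ShaderBindingTable &) = delete;
  ShaderBindingTable &operator=(const ShaderBindingTable &) = delete;

  GPUAPI getAPI() const { return API; }

protected:
  explicit ShaderBindingTable(GPUAPI API) : API(API) {}
};

// Hypothetical Vulkan subclass: classof answers "is this base pointer
// actually one of mine?" by inspecting the stored tag.
class ExampleVKShaderBindingTable : public ShaderBindingTable {
public:
  ExampleVKShaderBindingTable() : ShaderBindingTable(GPUAPI::Vulkan) {}
  static bool classof(const ShaderBindingTable *S) {
    return S->getAPI() == GPUAPI::Vulkan;
  }
};

// Hypothetical Metal subclass, present only so the cast has something to
// reject.
class ExampleMTLShaderBindingTable : public ShaderBindingTable {
public:
  ExampleMTLShaderBindingTable() : ShaderBindingTable(GPUAPI::Metal) {}
  static bool classof(const ShaderBindingTable *S) {
    return S->getAPI() == GPUAPI::Metal;
  }
};

// Simplified stand-in for llvm::dyn_cast: a checked downcast driven by
// To::classof instead of C++ RTTI.
template <typename To, typename From> const To *dynCast(const From *V) {
  return To::classof(V) ? static_cast<const To *>(V) : nullptr;
}

} // namespace sketch
```

A backend's dispatchRays can then accept the abstract `ShaderBindingTable` reference and downcast to its own concrete type, failing cleanly when handed a table built by a different backend.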

include/Support/Pipeline.h

Lines changed: 4 additions & 4 deletions
@@ -592,7 +592,7 @@ struct SBTEntry {
   llvm::SmallVector<uint8_t> LocalRootData;
 };

-struct ShaderBindingTable {
+struct ShaderBindingTableDesc {
   SBTEntry RayGen;
   llvm::SmallVector<SBTEntry> Miss;
   llvm::SmallVector<SBTEntry> HitGroup;
@@ -614,7 +614,7 @@ struct Pipeline {
   AccelerationStructureDescs AccelStructs;
   std::optional<RayTracingPipelineConfig> RTConfig;
   llvm::SmallVector<HitGroup> HitGroups;
-  std::optional<ShaderBindingTable> SBT;
+  std::optional<ShaderBindingTableDesc> SBT;

   uint32_t getVertexCount() const {
     if (DispatchParameters.VertexCount)
@@ -825,8 +825,8 @@ template <> struct MappingTraits<offloadtest::SBTEntry> {
   static void mapping(IO &I, offloadtest::SBTEntry &E);
 };

-template <> struct MappingTraits<offloadtest::ShaderBindingTable> {
-  static void mapping(IO &I, offloadtest::ShaderBindingTable &S);
+template <> struct MappingTraits<offloadtest::ShaderBindingTableDesc> {
+  static void mapping(IO &I, offloadtest::ShaderBindingTableDesc &S);
 };

 template <> struct ScalarEnumerationTraits<offloadtest::Rule> {

lib/API/DX/Device.cpp

Lines changed: 20 additions & 0 deletions
@@ -730,6 +730,12 @@ class DXComputeEncoder : public offloadtest::ComputeEncoder {
   // ID3D12Device5 entry point and helper allocators.
   llvm::Error batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) override;

+  llvm::Error dispatchRays(const PipelineState &, const ShaderBindingTable &,
+                           uint32_t, uint32_t, uint32_t) override {
+    return llvm::createStringError(
+        "RayTracing dispatchRays not yet supported on DirectX");
+  }
+
   void endEncodingImpl() override { popDebugGroup(); }
 };

@@ -1356,6 +1362,20 @@ class DXDevice : public offloadtest::Device {
     return std::make_unique<DXPipelineState>(Name, RootSig, PSO, std::nullopt);
   }

+  llvm::Expected<std::unique_ptr<PipelineState>>
+  createPipelineRT(llvm::StringRef, const BindingsDesc &,
+                   const RayTracingPipelineCreateDesc &) override {
+    return llvm::createStringError(
+        "RayTracing pipeline state not yet supported on DirectX");
+  }
+
+  llvm::Expected<std::unique_ptr<ShaderBindingTable>>
+  createShaderBindingTable(const PipelineState &,
+                           const ShaderBindingTableDesc &) override {
+    return llvm::createStringError(
+        "RayTracing shader binding table not yet supported on DirectX");
+  }
+
   llvm::Expected<std::unique_ptr<offloadtest::Fence>>
   createFence(llvm::StringRef Name) override {
     return DXFence::create(Device.Get(), Name);

lib/API/Device.cpp

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,7 @@
 #include "API/Device.h"
 #include "API/Encoder.h"
 #include "API/FormatConversion.h"
+#include "API/ShaderBindingTable.h"

 #include "Config.h"

@@ -39,6 +40,8 @@ RenderPass::~RenderPass() {}

 AccelerationStructure::~AccelerationStructure() {}

+ShaderBindingTable::~ShaderBindingTable() {}
+
 Device::~Device() {}

 llvm::Expected<llvm::SmallVector<std::unique_ptr<Device>>>

lib/API/MTL/MTLDevice.cpp

Lines changed: 20 additions & 0 deletions
@@ -602,6 +602,12 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
   // MTL::Device handle (used to allocate scratch and instance buffers).
   llvm::Error batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) override;

+  llvm::Error dispatchRays(const PipelineState &, const ShaderBindingTable &,
+                           uint32_t, uint32_t, uint32_t) override {
+    return llvm::createStringError(
+        "RayTracing dispatchRays not yet supported on Metal");
+  }
+
   /// Lazily transition into an AccelerationStructureCommandEncoder; mirrors
   /// the existing compute↔blit lazy switch.
   llvm::Error ensureASEncoder() {
@@ -1662,6 +1668,20 @@ class MTLDevice : public offloadtest::Device {

   Queue &getGraphicsQueue() override { return GraphicsQueue; }

+  llvm::Expected<std::unique_ptr<PipelineState>>
+  createPipelineRT(llvm::StringRef, const BindingsDesc &,
+                   const RayTracingPipelineCreateDesc &) override {
+    return llvm::createStringError(
+        "RayTracing pipeline state not yet supported on Metal");
+  }
+
+  llvm::Expected<std::unique_ptr<ShaderBindingTable>>
+  createShaderBindingTable(const PipelineState &,
+                           const ShaderBindingTableDesc &) override {
+    return llvm::createStringError(
+        "RayTracing shader binding table not yet supported on Metal");
+  }
+
   llvm::Expected<std::unique_ptr<offloadtest::Fence>>
   createFence(llvm::StringRef Name) override {
     return MTLFence::create(Device, Name);
