Skip to content

Commit cc34fbb

Browse files
MarijnS95claude
authored andcommitted
[VK] Add ray tracing pipeline, SBT, and DispatchRays bring-up
First per-backend bring-up in the PSO raytracing series (#1268). Adds the API surface (ComputeEncoder::dispatchRays, Device::createPipelineRT, Device::createShaderBindingTable, RayTracingPipelineCreateDesc) plus the Vulkan implementation behind it. D3D12 and Metal stub the new methods with not-yet-supported errors; their bring-up lands in follow-up PRs. The pre-existing YAML schema struct from PR #1270 is renamed ShaderBindingTable -> ShaderBindingTableDesc so the bare name is free for the runtime resource class (parallel to BLASDesc / TLASDesc vs AccelerationStructure). A new include/API/ShaderBindingTable.h holds the abstract runtime base; concrete backend SBT classes derive from it with LLVM-style classof / cast<>. The VulkanDevice's prior `RaytracingFunctions RT` lumped AS and RT pipeline entry points together. They split into two structs — `ASFunctions AS` and `RTPipelineFunctions RT` — matching the actual feature-gate split (AS+ray-query is a complete configuration on its own, RT pipeline is layered on top). `HasRayTracingSupport` renames to `HasASSupport`, and a separate `HasRTPipelineSupport` tracks the new VK_KHR_ray_tracing_pipeline extension. Vulkan bring-up: - Extension: VK_KHR_ray_tracing_pipeline is requested when reported, with VkPhysicalDeviceRayTracingPipelineFeaturesKHR chained into the pre-create feature query. After the query the gating rayTracingPipeline bool is checked; capture-replay / trace-rays- indirect / traversal-primitive-culling sub-features are cleared since the tests don't exercise them. - Function pointers: vkCreateRayTracingPipelinesKHR, vkGetRayTracingShaderGroupHandlesKHR, vkCmdTraceRaysKHR. - Properties: VkPhysicalDeviceRayTracingPipelinePropertiesKHR is cached at device-create time for SBT handle size / alignment / base-alignment. - VKRayTracingPipelineState derives from VulkanPipelineState; an IsRayTracing flag on the base lets the existing Vulkan cast<> path stay polymorphic without adding a new GPUAPI value. classof tests both the API and the flag. The derived class also carries a StringMap<uint32_t> resolving each shader EntryPoint or HitGroup Name to its index in the pipeline's group array, plus per-bucket counts so the SBT builder can slice the contiguous handle blob into raygen / miss / hit / callable regions. - createPipelineRT builds a single VkShaderModule (the DXIL library compiles to one SPIR-V module with multiple OpEntryPoints), then one VkPipelineShaderStageCreateInfo per Shader entry and one VkRayTracingShaderGroupCreateInfoKHR per general shader / hit group. Pipeline layout is shared with the compute path via createPipelineLayout, gated on all six RT stage flags so any binding can be consumed from any RT shader. - createShaderBindingTable allocates a host-visible coherent buffer big enough for four regions and lays out each entry as [handle bytes][localRootData bytes][padding-to-stride]. Per-region stride = align(handleSize + max-local-root-data-in-region, handleAlignment); per-region size = align(count * stride, baseAlignment). LocalRootData support comes free from the PR1 SBT schema; the test doesn't exercise it yet. Each region's VkStridedDeviceAddressRegionKHR derives from the buffer's vkGetBufferDeviceAddress. - dispatchRays binds the pipeline at VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, emits a pre-barrier with AS_READ + SHADER_READ/WRITE dst access into RAY_TRACING_SHADER_BIT_KHR, then calls vkCmdTraceRaysKHR with the SBT's four region structs. - createCommands picks the new bind point for RT pipelines so vkCmdBindDescriptorSets binds to the right point. executeProgram's isRayTracing branch builds a RayTracingPipelineCreateDesc from the YAML, calls createPipelineRT then createShaderBindingTable, and keeps both on InvocationState for the dispatch. raygen-roundtrip.test now expects DirectX/Metal/Clang to XFAIL; on a DXC + Vulkan combo with VK_KHR_ray_tracing_pipeline supported the test should PASS via this implementation. On the user's Linux + clang-dxc loop the test still XFAILs because clang-dxc doesn't yet lower [shader("raygeneration")] entry points to SPIR-V, so the Clang XFAIL token catches the compile failure. CI on a working DXC install will exercise the runtime path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 1ddd4d9 commit cc34fbb

10 files changed

Lines changed: 755 additions & 40 deletions

File tree

include/API/Device.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "API/Capabilities.h"
2121
#include "API/CommandBuffer.h"
2222
#include "API/RenderPass.h"
23+
#include "API/ShaderBindingTable.h"
2324
#include "API/Texture.h"
2425

2526
#include "Support/Pipeline.h"
@@ -84,6 +85,21 @@ struct ShaderContainer {
8485
llvm::SmallVector<SpecializationConstant> SpecializationConstants;
8586
};
8687

88+
struct RayTracingShader {
89+
Stages Stage;
90+
std::string EntryPoint;
91+
};
92+
93+
struct RayTracingPipelineCreateDesc {
94+
// All RT shaders are compiled into a single DXIL library; every entry in
95+
// `Shaders` references this same blob via the backend's library-loading
96+
// path.
97+
const llvm::MemoryBuffer *Library = nullptr;
98+
llvm::SmallVector<RayTracingShader> Shaders;
99+
llvm::SmallVector<HitGroup> HitGroups;
100+
RayTracingPipelineConfig Config;
101+
};
102+
87103
struct TraditionalRasterPipelineCreateDesc {
88104
llvm::SmallVector<InputLayoutDesc> InputLayout;
89105
llvm::SmallVector<Format> RTFormats;
@@ -259,6 +275,14 @@ class Device {
259275
llvm::StringRef Name, const BindingsDesc &BindingsDesc,
260276
const MeshShaderRasterPipelineCreateDesc &Desc) = 0;
261277

278+
virtual llvm::Expected<std::unique_ptr<PipelineState>>
279+
createPipelineRT(llvm::StringRef Name, const BindingsDesc &BindingsDesc,
280+
const RayTracingPipelineCreateDesc &Desc) = 0;
281+
282+
virtual llvm::Expected<std::unique_ptr<ShaderBindingTable>>
283+
createShaderBindingTable(const PipelineState &PSO,
284+
const ShaderBindingTableDesc &Desc) = 0;
285+
262286
virtual llvm::Expected<std::unique_ptr<Fence>>
263287
createFence(llvm::StringRef Name) = 0;
264288

include/API/Encoder.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class Buffer;
2525
class Texture;
2626
class PipelineState;
2727
class AccelerationStructure;
28+
class ShaderBindingTable;
2829
struct BLASBuildRequest;
2930
struct TLASBuildRequest;
3031

@@ -110,6 +111,17 @@ class ComputeEncoder : public CommandEncoder {
110111
/// Metal). A barrier covering AS-build writes is implicitly emitted before
111112
/// any subsequent command that reads from the freshly-built structures.
112113
virtual llvm::Error batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) = 0;
114+
115+
/// Trace rays from a RayTracing pipeline. \p PSO must have been created via
116+
/// Device::createPipelineRT and \p SBT via Device::createShaderBindingTable
117+
/// on that same PSO. \p Width, \p Height, \p Depth are the dispatch
118+
/// dimensions passed through to the backend's DispatchRays equivalent
119+
/// (D3D12 DispatchRays, Vulkan vkCmdTraceRaysKHR, Metal compute dispatch
120+
/// after metal_irconverter lowering).
121+
virtual llvm::Error dispatchRays(const PipelineState &PSO,
122+
const ShaderBindingTable &SBT,
123+
uint32_t Width, uint32_t Height,
124+
uint32_t Depth) = 0;
113125
};
114126

115127
struct Viewport {

include/API/ShaderBindingTable.h

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//===- ShaderBindingTable.h - Offload RT Shader Binding Table -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef OFFLOADTEST_API_SHADERBINDINGTABLE_H
10+
#define OFFLOADTEST_API_SHADERBINDINGTABLE_H
11+
12+
#include "API/API.h"
13+
14+
#include <cstdint>
15+
16+
namespace offloadtest {
17+
18+
struct ShaderBindingTableDesc;
19+
20+
/// Runtime shader binding table built from a RayTracing PipelineState plus a
21+
/// ShaderBindingTableDesc. Concrete subclasses (one per backend) hold the
22+
/// device-side records and any address ranges needed by the backend's
23+
/// DispatchRays call.
24+
class ShaderBindingTable {
25+
GPUAPI API;
26+
27+
public:
28+
virtual ~ShaderBindingTable();
29+
ShaderBindingTable(const ShaderBindingTable &) = delete;
30+
ShaderBindingTable &operator=(const ShaderBindingTable &) = delete;
31+
32+
GPUAPI getAPI() const { return API; }
33+
34+
protected:
35+
explicit ShaderBindingTable(GPUAPI API) : API(API) {}
36+
};
37+
38+
/// Per-region SBT layout numbers.
39+
///
40+
/// Every backend lays an SBT out as four concatenated regions (raygen, miss,
41+
/// hit-group, callable). Within a region every record is
42+
/// `[shader-identifier][LocalRootData][padding-to-stride]`, where stride is
43+
/// `align(identifierSize + max-LocalRootData-in-region, RecordAlign)` and
44+
/// the region itself is aligned to `BaseAlign`. The numbers match the
45+
/// alignment rules of both D3D12
46+
/// (`D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES` / `…_RECORD_BYTE_ALIGNMENT` /
47+
/// `…_TABLE_BYTE_ALIGNMENT`) and Vulkan
48+
/// (`shaderGroupHandleSize` / `shaderGroupHandleAlignment` /
49+
/// `shaderGroupBaseAlignment`); backend-specific constants are passed in.
50+
struct SBTRegionLayout {
51+
uint32_t Stride = 0;
52+
uint32_t Size = 0;
53+
uint32_t Offset = 0; // byte offset from the start of the SBT buffer
54+
};
55+
56+
struct SBTLayout {
57+
SBTRegionLayout RayGen;
58+
SBTRegionLayout Miss;
59+
SBTRegionLayout HitGroup;
60+
SBTRegionLayout Callable;
61+
uint32_t TotalSize = 0;
62+
};
63+
64+
/// Compute the per-region layout for an SBT description.
65+
///
66+
/// \p IdentifierSize is the size of one shader identifier (32 bytes on
67+
/// D3D12; `shaderGroupHandleSize` on Vulkan).
68+
/// \p RecordAlign is the per-record alignment (D3D12: 32 bytes; Vulkan:
69+
/// `shaderGroupHandleAlignment`).
70+
/// \p BaseAlign is the per-region alignment (D3D12: 64 bytes; Vulkan:
71+
/// `shaderGroupBaseAlignment`).
72+
SBTLayout computeSBTLayout(uint32_t IdentifierSize, uint32_t RecordAlign,
73+
uint32_t BaseAlign,
74+
const ShaderBindingTableDesc &Desc);
75+
76+
} // namespace offloadtest
77+
78+
#endif // OFFLOADTEST_API_SHADERBINDINGTABLE_H

include/Support/Pipeline.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ struct SBTEntry {
593593
llvm::SmallVector<uint8_t> LocalRootData;
594594
};
595595

596-
struct ShaderBindingTable {
596+
struct ShaderBindingTableDesc {
597597
SBTEntry RayGen;
598598
llvm::SmallVector<SBTEntry> Miss;
599599
llvm::SmallVector<SBTEntry> HitGroup;
@@ -615,7 +615,7 @@ struct Pipeline {
615615
AccelerationStructureDescs AccelStructs;
616616
std::optional<RayTracingPipelineConfig> RTConfig;
617617
llvm::SmallVector<HitGroup> HitGroups;
618-
std::optional<ShaderBindingTable> SBT;
618+
std::optional<ShaderBindingTableDesc> SBT;
619619

620620
uint32_t getVertexCount() const {
621621
if (DispatchParameters.VertexCount)
@@ -826,8 +826,8 @@ template <> struct MappingTraits<offloadtest::SBTEntry> {
826826
static void mapping(IO &I, offloadtest::SBTEntry &E);
827827
};
828828

829-
template <> struct MappingTraits<offloadtest::ShaderBindingTable> {
830-
static void mapping(IO &I, offloadtest::ShaderBindingTable &S);
829+
template <> struct MappingTraits<offloadtest::ShaderBindingTableDesc> {
830+
static void mapping(IO &I, offloadtest::ShaderBindingTableDesc &S);
831831
};
832832

833833
template <> struct ScalarEnumerationTraits<offloadtest::Rule> {

lib/API/DX/Device.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,12 @@ class DXComputeEncoder : public offloadtest::ComputeEncoder {
846846
// ID3D12Device5 entry point and helper allocators.
847847
llvm::Error batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) override;
848848

849+
llvm::Error dispatchRays(const PipelineState &, const ShaderBindingTable &,
850+
uint32_t, uint32_t, uint32_t) override {
851+
return llvm::createStringError(
852+
"RayTracing dispatchRays not yet supported on DirectX");
853+
}
854+
849855
void endEncodingImpl() override { popDebugGroup(); }
850856
};
851857

@@ -1532,6 +1538,20 @@ class DXDevice : public offloadtest::Device {
15321538
return std::make_unique<DXPipelineState>(Name, RootSig, PSO, std::nullopt);
15331539
}
15341540

1541+
llvm::Expected<std::unique_ptr<PipelineState>>
1542+
createPipelineRT(llvm::StringRef, const BindingsDesc &,
1543+
const RayTracingPipelineCreateDesc &) override {
1544+
return llvm::createStringError(
1545+
"RayTracing pipeline state not yet supported on DirectX");
1546+
}
1547+
1548+
llvm::Expected<std::unique_ptr<ShaderBindingTable>>
1549+
createShaderBindingTable(const PipelineState &,
1550+
const ShaderBindingTableDesc &) override {
1551+
return llvm::createStringError(
1552+
"RayTracing shader binding table not yet supported on DirectX");
1553+
}
1554+
15351555
llvm::Expected<std::unique_ptr<offloadtest::Fence>>
15361556
createFence(llvm::StringRef Name) override {
15371557
return DXFence::create(Device.Get(), Name);

lib/API/Device.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "API/Device.h"
1313
#include "API/Encoder.h"
1414
#include "API/FormatConversion.h"
15+
#include "API/ShaderBindingTable.h"
1516

1617
#include "Config.h"
1718

@@ -39,6 +40,56 @@ RenderPass::~RenderPass() {}
3940

4041
AccelerationStructure::~AccelerationStructure() {}
4142

43+
ShaderBindingTable::~ShaderBindingTable() {}
44+
45+
static uint32_t alignUp(uint32_t Value, uint32_t Alignment) {
46+
return (Value + Alignment - 1) & ~(Alignment - 1);
47+
}
48+
49+
SBTLayout offloadtest::computeSBTLayout(uint32_t IdentifierSize,
50+
uint32_t RecordAlign,
51+
uint32_t BaseAlign,
52+
const ShaderBindingTableDesc &Desc) {
53+
auto StrideFor = [&](llvm::ArrayRef<SBTEntry> Entries) {
54+
size_t MaxLocal = 0;
55+
for (const auto &E : Entries)
56+
MaxLocal = std::max<size_t>(MaxLocal, E.LocalRootData.size());
57+
return alignUp(IdentifierSize + static_cast<uint32_t>(MaxLocal),
58+
RecordAlign);
59+
};
60+
auto RegionSize = [&](uint32_t Count, uint32_t Stride) {
61+
return Count == 0 ? 0u : alignUp(Count * Stride, BaseAlign);
62+
};
63+
64+
// Vulkan dispatches exactly one raygen per vkCmdTraceRaysKHR and D3D12's
65+
// RayGenerationShaderRecord field is a single record; the descriptor only
66+
// carries one raygen entry.
67+
const llvm::ArrayRef<SBTEntry> RGEntries(&Desc.RayGen, 1);
68+
69+
SBTLayout L;
70+
L.RayGen.Stride = StrideFor(RGEntries);
71+
L.RayGen.Size = RegionSize(1, L.RayGen.Stride);
72+
L.RayGen.Offset = 0;
73+
74+
L.Miss.Stride = StrideFor(Desc.Miss);
75+
L.Miss.Size =
76+
RegionSize(static_cast<uint32_t>(Desc.Miss.size()), L.Miss.Stride);
77+
L.Miss.Offset = L.RayGen.Offset + L.RayGen.Size;
78+
79+
L.HitGroup.Stride = StrideFor(Desc.HitGroup);
80+
L.HitGroup.Size = RegionSize(static_cast<uint32_t>(Desc.HitGroup.size()),
81+
L.HitGroup.Stride);
82+
L.HitGroup.Offset = L.Miss.Offset + L.Miss.Size;
83+
84+
L.Callable.Stride = StrideFor(Desc.Callable);
85+
L.Callable.Size = RegionSize(static_cast<uint32_t>(Desc.Callable.size()),
86+
L.Callable.Stride);
87+
L.Callable.Offset = L.HitGroup.Offset + L.HitGroup.Size;
88+
89+
L.TotalSize = L.Callable.Offset + L.Callable.Size;
90+
return L;
91+
}
92+
4293
Device::~Device() {}
4394

4495
llvm::Expected<llvm::SmallVector<std::unique_ptr<Device>>>

lib/API/MTL/MTLDevice.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,12 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
625625
// MTL::Device handle (used to allocate scratch and instance buffers).
626626
llvm::Error batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) override;
627627

628+
llvm::Error dispatchRays(const PipelineState &, const ShaderBindingTable &,
629+
uint32_t, uint32_t, uint32_t) override {
630+
return llvm::createStringError(
631+
"RayTracing dispatchRays not yet supported on Metal");
632+
}
633+
628634
/// Lazily transition into an AccelerationStructureCommandEncoder; mirrors
629635
/// the existing compute↔blit lazy switch.
630636
llvm::Error ensureASEncoder() {
@@ -1711,6 +1717,20 @@ class MTLDevice : public offloadtest::Device {
17111717

17121718
Queue &getGraphicsQueue() override { return GraphicsQueue; }
17131719

1720+
llvm::Expected<std::unique_ptr<PipelineState>>
1721+
createPipelineRT(llvm::StringRef, const BindingsDesc &,
1722+
const RayTracingPipelineCreateDesc &) override {
1723+
return llvm::createStringError(
1724+
"RayTracing pipeline state not yet supported on Metal");
1725+
}
1726+
1727+
llvm::Expected<std::unique_ptr<ShaderBindingTable>>
1728+
createShaderBindingTable(const PipelineState &,
1729+
const ShaderBindingTableDesc &) override {
1730+
return llvm::createStringError(
1731+
"RayTracing shader binding table not yet supported on Metal");
1732+
}
1733+
17141734
llvm::Expected<std::unique_ptr<offloadtest::Fence>>
17151735
createFence(llvm::StringRef Name) override {
17161736
return MTLFence::create(Device, Name);

0 commit comments

Comments
 (0)