Skip to content

Commit 36c89b6

Browse files
MarijnS95claude
andcommitted
[VK] Add ray tracing pipeline, SBT, and DispatchRays bring-up
First per-backend bring-up in the PSO raytracing series (#1268). Adds the API surface (ComputeEncoder::dispatchRays, Device::createPipelineRT, Device::createShaderBindingTable, RayTracingPipelineCreateDesc) plus the Vulkan implementation behind it. D3D12 and Metal stub the new methods with not-yet-supported errors; their bring-up lands in follow-up PRs. The pre-existing YAML schema struct from PR #1270 is renamed ShaderBindingTable -> ShaderBindingTableDesc so the bare name is free for the runtime resource class (parallel to BLASDesc / TLASDesc vs AccelerationStructure). A new include/API/ShaderBindingTable.h holds the abstract runtime base; concrete backend SBT classes derive from it with LLVM-style classof / cast<>. The VulkanDevice's prior `RaytracingFunctions RT` lumped AS and RT pipeline entry points together. They split into two structs — `ASFunctions AS` and `RTPipelineFunctions RT` — matching the actual feature-gate split (AS+ray-query is a complete configuration on its own, RT pipeline is layered on top). `HasRayTracingSupport` renames to `HasASSupport`, and a separate `HasRTPipelineSupport` tracks the new VK_KHR_ray_tracing_pipeline extension. Vulkan bring-up: - Extension: VK_KHR_ray_tracing_pipeline is requested when reported, with VkPhysicalDeviceRayTracingPipelineFeaturesKHR chained into the pre-create feature query. After the query the gating rayTracingPipeline bool is checked; capture-replay / trace-rays- indirect / traversal-primitive-culling sub-features are cleared since the tests don't exercise them. - Function pointers: vkCreateRayTracingPipelinesKHR, vkGetRayTracingShaderGroupHandlesKHR, vkCmdTraceRaysKHR. - Properties: VkPhysicalDeviceRayTracingPipelinePropertiesKHR is cached at device-create time for SBT handle size / alignment / base-alignment. - VKRayTracingPipelineState derives from VulkanPipelineState; an IsRayTracing flag on the base lets the existing Vulkan cast<> path stay polymorphic without adding a new GPUAPI value. classof tests both the API and the flag. The derived class also carries a StringMap<uint32_t> resolving each shader EntryPoint or HitGroup Name to its index in the pipeline's group array, plus per-bucket counts so the SBT builder can slice the contiguous handle blob into raygen / miss / hit / callable regions. - createPipelineRT builds a single VkShaderModule (the DXIL library compiles to one SPIR-V module with multiple OpEntryPoints), then one VkPipelineShaderStageCreateInfo per Shader entry and one VkRayTracingShaderGroupCreateInfoKHR per general shader / hit group. Pipeline layout is shared with the compute path via createPipelineLayout, gated on all six RT stage flags so any binding can be consumed from any RT shader. - createShaderBindingTable allocates a host-visible coherent buffer big enough for four regions and lays out each entry as [handle bytes][localRootData bytes][padding-to-stride]. Per-region stride = align(handleSize + max-local-root-data-in-region, handleAlignment); per-region size = align(count * stride, baseAlignment). LocalRootData support comes free from the PR1 SBT schema; the test doesn't exercise it yet. Each region's VkStridedDeviceAddressRegionKHR derives from the buffer's vkGetBufferDeviceAddress. - dispatchRays binds the pipeline at VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, emits a pre-barrier with AS_READ + SHADER_READ/WRITE dst access into RAY_TRACING_SHADER_BIT_KHR, then calls vkCmdTraceRaysKHR with the SBT's four region structs. - createCommands picks the new bind point for RT pipelines so vkCmdBindDescriptorSets binds to the right point. executeProgram's isRayTracing branch builds a RayTracingPipelineCreateDesc from the YAML, calls createPipelineRT then createShaderBindingTable, and keeps both on InvocationState for the dispatch. raygen-roundtrip.test now expects DirectX/Metal/Clang to XFAIL; on a DXC + Vulkan combo with VK_KHR_ray_tracing_pipeline supported the test should PASS via this implementation. On the user's Linux + clang-dxc loop the test still XFAILs because clang-dxc doesn't yet lower [shader("raygeneration")] entry points to SPIR-V, so the Clang XFAIL token catches the compile failure. CI on a working DXC install will exercise the runtime path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 4dffff2 commit 36c89b6

10 files changed

Lines changed: 761 additions & 46 deletions

File tree

include/API/Device.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "API/Capabilities.h"
2121
#include "API/CommandBuffer.h"
2222
#include "API/RenderPass.h"
23+
#include "API/ShaderBindingTable.h"
2324
#include "API/Texture.h"
2425

2526
#include "Support/Pipeline.h"
@@ -83,6 +84,21 @@ struct ShaderContainer {
8384
llvm::SmallVector<SpecializationConstant> SpecializationConstants;
8485
};
8586

87+
struct RayTracingShader {
88+
Stages Stage;
89+
std::string EntryPoint;
90+
};
91+
92+
struct RayTracingPipelineCreateDesc {
93+
// All RT shaders are compiled into a single DXIL library; every entry in
94+
// `Shaders` references this same blob via the backend's library-loading
95+
// path.
96+
const llvm::MemoryBuffer *Library = nullptr;
97+
llvm::SmallVector<RayTracingShader> Shaders;
98+
llvm::SmallVector<HitGroup> HitGroups;
99+
RayTracingPipelineConfig Config;
100+
};
101+
86102
struct TraditionalRasterPipelineCreateDesc {
87103
llvm::SmallVector<InputLayoutDesc> InputLayout;
88104
llvm::SmallVector<Format> RTFormats;
@@ -215,6 +231,14 @@ class Device {
215231
llvm::StringRef Name, const BindingsDesc &BindingsDesc,
216232
const TraditionalRasterPipelineCreateDesc &Desc) = 0;
217233

234+
virtual llvm::Expected<std::unique_ptr<PipelineState>>
235+
createPipelineRT(llvm::StringRef Name, const BindingsDesc &BindingsDesc,
236+
const RayTracingPipelineCreateDesc &Desc) = 0;
237+
238+
virtual llvm::Expected<std::unique_ptr<ShaderBindingTable>>
239+
createShaderBindingTable(const PipelineState &PSO,
240+
const ShaderBindingTableDesc &Desc) = 0;
241+
218242
virtual llvm::Expected<std::unique_ptr<Fence>>
219243
createFence(llvm::StringRef Name) = 0;
220244

include/API/Encoder.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ namespace offloadtest {
2424
class Buffer;
2525
class PipelineState;
2626
class AccelerationStructure;
27+
class ShaderBindingTable;
2728
struct BLASBuildRequest;
2829
struct TLASBuildRequest;
2930

@@ -105,6 +106,17 @@ class ComputeEncoder : public CommandEncoder {
105106
/// Metal). A barrier covering AS-build writes is implicitly emitted before
106107
/// any subsequent command that reads from the freshly-built structures.
107108
virtual llvm::Error batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) = 0;
109+
110+
/// Trace rays from a RayTracing pipeline. \p PSO must have been created via
111+
/// Device::createPipelineRT and \p SBT via Device::createShaderBindingTable
112+
/// on that same PSO. \p Width, \p Height, \p Depth are the dispatch
113+
/// dimensions passed through to the backend's DispatchRays equivalent
114+
/// (D3D12 DispatchRays, Vulkan vkCmdTraceRaysKHR, Metal compute dispatch
115+
/// after metal_irconverter lowering).
116+
virtual llvm::Error dispatchRays(const PipelineState &PSO,
117+
const ShaderBindingTable &SBT,
118+
uint32_t Width, uint32_t Height,
119+
uint32_t Depth) = 0;
108120
};
109121

110122
struct Viewport {

include/API/ShaderBindingTable.h

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//===- ShaderBindingTable.h - Offload RT Shader Binding Table -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef OFFLOADTEST_API_SHADERBINDINGTABLE_H
10+
#define OFFLOADTEST_API_SHADERBINDINGTABLE_H
11+
12+
#include "API/API.h"
13+
14+
#include <cstdint>
15+
16+
namespace offloadtest {
17+
18+
struct ShaderBindingTableDesc;
19+
20+
/// Runtime shader binding table built from a RayTracing PipelineState plus a
21+
/// ShaderBindingTableDesc. Concrete subclasses (one per backend) hold the
22+
/// device-side records and any address ranges needed by the backend's
23+
/// DispatchRays call.
24+
class ShaderBindingTable {
25+
GPUAPI API;
26+
27+
public:
28+
virtual ~ShaderBindingTable();
29+
ShaderBindingTable(const ShaderBindingTable &) = delete;
30+
ShaderBindingTable &operator=(const ShaderBindingTable &) = delete;
31+
32+
GPUAPI getAPI() const { return API; }
33+
34+
protected:
35+
explicit ShaderBindingTable(GPUAPI API) : API(API) {}
36+
};
37+
38+
/// Per-region SBT layout numbers.
39+
///
40+
/// Every backend lays an SBT out as four concatenated regions (raygen, miss,
41+
/// hit-group, callable). Within a region every record is
42+
/// `[shader-identifier][LocalRootData][padding-to-stride]`, where stride is
43+
/// `align(identifierSize + max-LocalRootData-in-region, RecordAlign)` and
44+
/// the region itself is aligned to `BaseAlign`. The numbers match the
45+
/// alignment rules of both D3D12
46+
/// (`D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES` / `…_RECORD_BYTE_ALIGNMENT` /
47+
/// `…_TABLE_BYTE_ALIGNMENT`) and Vulkan
48+
/// (`shaderGroupHandleSize` / `shaderGroupHandleAlignment` /
49+
/// `shaderGroupBaseAlignment`); backend-specific constants are passed in.
50+
struct SBTRegionLayout {
51+
uint32_t Stride = 0;
52+
uint32_t Size = 0;
53+
uint32_t Offset = 0; // byte offset from the start of the SBT buffer
54+
};
55+
56+
struct SBTLayout {
57+
SBTRegionLayout RayGen;
58+
SBTRegionLayout Miss;
59+
SBTRegionLayout HitGroup;
60+
SBTRegionLayout Callable;
61+
uint32_t TotalSize = 0;
62+
};
63+
64+
/// Compute the per-region layout for an SBT description.
65+
///
66+
/// \p IdentifierSize is the size of one shader identifier (32 bytes on
67+
/// D3D12; `shaderGroupHandleSize` on Vulkan).
68+
/// \p RecordAlign is the per-record alignment (D3D12: 32 bytes; Vulkan:
69+
/// `shaderGroupHandleAlignment`).
70+
/// \p BaseAlign is the per-region alignment (D3D12: 64 bytes; Vulkan:
71+
/// `shaderGroupBaseAlignment`).
72+
SBTLayout computeSBTLayout(uint32_t IdentifierSize, uint32_t RecordAlign,
73+
uint32_t BaseAlign,
74+
const ShaderBindingTableDesc &Desc);
75+
76+
} // namespace offloadtest
77+
78+
#endif // OFFLOADTEST_API_SHADERBINDINGTABLE_H

include/Support/Pipeline.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,7 @@ struct SBTEntry {
592592
llvm::SmallVector<uint8_t> LocalRootData;
593593
};
594594

595-
struct ShaderBindingTable {
595+
struct ShaderBindingTableDesc {
596596
SBTEntry RayGen;
597597
llvm::SmallVector<SBTEntry> Miss;
598598
llvm::SmallVector<SBTEntry> HitGroup;
@@ -614,7 +614,7 @@ struct Pipeline {
614614
AccelerationStructureDescs AccelStructs;
615615
std::optional<RayTracingPipelineConfig> RTConfig;
616616
llvm::SmallVector<HitGroup> HitGroups;
617-
std::optional<ShaderBindingTable> SBT;
617+
std::optional<ShaderBindingTableDesc> SBT;
618618

619619
uint32_t getVertexCount() const {
620620
if (DispatchParameters.VertexCount)
@@ -825,8 +825,8 @@ template <> struct MappingTraits<offloadtest::SBTEntry> {
825825
static void mapping(IO &I, offloadtest::SBTEntry &E);
826826
};
827827

828-
template <> struct MappingTraits<offloadtest::ShaderBindingTable> {
829-
static void mapping(IO &I, offloadtest::ShaderBindingTable &S);
828+
template <> struct MappingTraits<offloadtest::ShaderBindingTableDesc> {
829+
static void mapping(IO &I, offloadtest::ShaderBindingTableDesc &S);
830830
};
831831

832832
template <> struct ScalarEnumerationTraits<offloadtest::Rule> {

lib/API/DX/Device.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,12 @@ class DXComputeEncoder : public offloadtest::ComputeEncoder {
730730
// ID3D12Device5 entry point and helper allocators.
731731
llvm::Error batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) override;
732732

733+
llvm::Error dispatchRays(const PipelineState &, const ShaderBindingTable &,
734+
uint32_t, uint32_t, uint32_t) override {
735+
return llvm::createStringError(
736+
"RayTracing dispatchRays not yet supported on DirectX");
737+
}
738+
733739
void endEncodingImpl() override { popDebugGroup(); }
734740
};
735741

@@ -1356,6 +1362,20 @@ class DXDevice : public offloadtest::Device {
13561362
return std::make_unique<DXPipelineState>(Name, RootSig, PSO, std::nullopt);
13571363
}
13581364

1365+
llvm::Expected<std::unique_ptr<PipelineState>>
1366+
createPipelineRT(llvm::StringRef, const BindingsDesc &,
1367+
const RayTracingPipelineCreateDesc &) override {
1368+
return llvm::createStringError(
1369+
"RayTracing pipeline state not yet supported on DirectX");
1370+
}
1371+
1372+
llvm::Expected<std::unique_ptr<ShaderBindingTable>>
1373+
createShaderBindingTable(const PipelineState &,
1374+
const ShaderBindingTableDesc &) override {
1375+
return llvm::createStringError(
1376+
"RayTracing shader binding table not yet supported on DirectX");
1377+
}
1378+
13591379
llvm::Expected<std::unique_ptr<offloadtest::Fence>>
13601380
createFence(llvm::StringRef Name) override {
13611381
return DXFence::create(Device.Get(), Name);

lib/API/Device.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "API/Device.h"
1313
#include "API/Encoder.h"
1414
#include "API/FormatConversion.h"
15+
#include "API/ShaderBindingTable.h"
1516

1617
#include "Config.h"
1718

@@ -39,6 +40,56 @@ RenderPass::~RenderPass() {}
3940

4041
AccelerationStructure::~AccelerationStructure() {}
4142

43+
ShaderBindingTable::~ShaderBindingTable() {}
44+
45+
static uint32_t alignUp(uint32_t Value, uint32_t Alignment) {
46+
return (Value + Alignment - 1) & ~(Alignment - 1);
47+
}
48+
49+
SBTLayout offloadtest::computeSBTLayout(uint32_t IdentifierSize,
50+
uint32_t RecordAlign,
51+
uint32_t BaseAlign,
52+
const ShaderBindingTableDesc &Desc) {
53+
auto StrideFor = [&](llvm::ArrayRef<SBTEntry> Entries) {
54+
size_t MaxLocal = 0;
55+
for (const auto &E : Entries)
56+
MaxLocal = std::max<size_t>(MaxLocal, E.LocalRootData.size());
57+
return alignUp(IdentifierSize + static_cast<uint32_t>(MaxLocal),
58+
RecordAlign);
59+
};
60+
auto RegionSize = [&](uint32_t Count, uint32_t Stride) {
61+
return Count == 0 ? 0u : alignUp(Count * Stride, BaseAlign);
62+
};
63+
64+
// Vulkan dispatches exactly one raygen per vkCmdTraceRaysKHR and D3D12's
65+
// RayGenerationShaderRecord field is a single record; the descriptor only
66+
// carries one raygen entry.
67+
const llvm::ArrayRef<SBTEntry> RGEntries(&Desc.RayGen, 1);
68+
69+
SBTLayout L;
70+
L.RayGen.Stride = StrideFor(RGEntries);
71+
L.RayGen.Size = RegionSize(1, L.RayGen.Stride);
72+
L.RayGen.Offset = 0;
73+
74+
L.Miss.Stride = StrideFor(Desc.Miss);
75+
L.Miss.Size =
76+
RegionSize(static_cast<uint32_t>(Desc.Miss.size()), L.Miss.Stride);
77+
L.Miss.Offset = L.RayGen.Offset + L.RayGen.Size;
78+
79+
L.HitGroup.Stride = StrideFor(Desc.HitGroup);
80+
L.HitGroup.Size = RegionSize(static_cast<uint32_t>(Desc.HitGroup.size()),
81+
L.HitGroup.Stride);
82+
L.HitGroup.Offset = L.Miss.Offset + L.Miss.Size;
83+
84+
L.Callable.Stride = StrideFor(Desc.Callable);
85+
L.Callable.Size = RegionSize(static_cast<uint32_t>(Desc.Callable.size()),
86+
L.Callable.Stride);
87+
L.Callable.Offset = L.HitGroup.Offset + L.HitGroup.Size;
88+
89+
L.TotalSize = L.Callable.Offset + L.Callable.Size;
90+
return L;
91+
}
92+
4293
Device::~Device() {}
4394

4495
llvm::Expected<llvm::SmallVector<std::unique_ptr<Device>>>

lib/API/MTL/MTLDevice.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,12 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
602602
// MTL::Device handle (used to allocate scratch and instance buffers).
603603
llvm::Error batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) override;
604604

605+
llvm::Error dispatchRays(const PipelineState &, const ShaderBindingTable &,
606+
uint32_t, uint32_t, uint32_t) override {
607+
return llvm::createStringError(
608+
"RayTracing dispatchRays not yet supported on Metal");
609+
}
610+
605611
/// Lazily transition into an AccelerationStructureCommandEncoder; mirrors
606612
/// the existing compute↔blit lazy switch.
607613
llvm::Error ensureASEncoder() {
@@ -1662,6 +1668,20 @@ class MTLDevice : public offloadtest::Device {
16621668

16631669
Queue &getGraphicsQueue() override { return GraphicsQueue; }
16641670

1671+
llvm::Expected<std::unique_ptr<PipelineState>>
1672+
createPipelineRT(llvm::StringRef, const BindingsDesc &,
1673+
const RayTracingPipelineCreateDesc &) override {
1674+
return llvm::createStringError(
1675+
"RayTracing pipeline state not yet supported on Metal");
1676+
}
1677+
1678+
llvm::Expected<std::unique_ptr<ShaderBindingTable>>
1679+
createShaderBindingTable(const PipelineState &,
1680+
const ShaderBindingTableDesc &) override {
1681+
return llvm::createStringError(
1682+
"RayTracing shader binding table not yet supported on Metal");
1683+
}
1684+
16651685
llvm::Expected<std::unique_ptr<offloadtest::Fence>>
16661686
createFence(llvm::StringRef Name) override {
16671687
return MTLFence::create(Device, Name);

0 commit comments

Comments
 (0)