Commit a23da10

MarijnS95 and claude committed
[VK] Add ray tracing pipeline, SBT, and DispatchRays bring-up
First per-backend bring-up in the PSO raytracing series (#1268). Adds the API surface (ComputeEncoder::dispatchRays, Device::createPipelineRT, Device::createShaderBindingTable, RayTracingPipelineCreateDesc) plus the Vulkan implementation behind it. D3D12 and Metal stub the new methods with not-yet-supported errors; their bring-up lands in follow-up PRs.

The pre-existing YAML schema struct from PR #1270 is renamed ShaderBindingTable -> ShaderBindingTableDesc so the bare name is free for the runtime resource class (parallel to BLASDesc / TLASDesc vs AccelerationStructure). A new include/API/ShaderBindingTable.h holds the abstract runtime base; concrete backend SBT classes derive from it with LLVM-style classof / cast<>.

The VulkanDevice's prior `RaytracingFunctions RT` lumped AS and RT pipeline entry points together. They split into two structs, `ASFunctions AS` and `RTPipelineFunctions RT`, matching the actual feature-gate split (AS+ray-query is a complete configuration on its own, RT pipeline is layered on top). `HasRayTracingSupport` renames to `HasASSupport`, and a separate `HasRTPipelineSupport` tracks the new VK_KHR_ray_tracing_pipeline extension.

Vulkan bring-up:

- Extension: VK_KHR_ray_tracing_pipeline is requested when reported, with VkPhysicalDeviceRayTracingPipelineFeaturesKHR chained into the pre-create feature query. After the query the gating rayTracingPipeline bool is checked; capture-replay / trace-rays-indirect / traversal-primitive-culling sub-features are cleared since the tests don't exercise them.
- Function pointers: vkCreateRayTracingPipelinesKHR, vkGetRayTracingShaderGroupHandlesKHR, vkCmdTraceRaysKHR.
- Properties: VkPhysicalDeviceRayTracingPipelinePropertiesKHR is cached at device-create time for SBT handle size / alignment / base-alignment.
- VKRayTracingPipelineState derives from VulkanPipelineState; an IsRayTracing flag on the base lets the existing Vulkan cast<> path stay polymorphic without adding a new GPUAPI value. classof tests both the API and the flag. The derived class also carries a StringMap<uint32_t> resolving each shader EntryPoint or HitGroup Name to its index in the pipeline's group array, plus per-bucket counts so the SBT builder can slice the contiguous handle blob into raygen / miss / hit / callable regions.
- createPipelineRT builds a single VkShaderModule (the DXIL library compiles to one SPIR-V module with multiple OpEntryPoints), then one VkPipelineShaderStageCreateInfo per Shader entry and one VkRayTracingShaderGroupCreateInfoKHR per general shader / hit group. Pipeline layout is shared with the compute path via createPipelineLayout, gated on all six RT stage flags so any binding can be consumed from any RT shader.
- createShaderBindingTable allocates a host-visible coherent buffer big enough for four regions and lays out each entry as [handle bytes][localRootData bytes][padding-to-stride]. Per-region stride = align(handleSize + max-local-root-data-in-region, handleAlignment); per-region size = align(count * stride, baseAlignment). LocalRootData support comes free from the PR1 SBT schema; the test doesn't exercise it yet. Each region's VkStridedDeviceAddressRegionKHR derives from the buffer's vkGetBufferDeviceAddress.
- dispatchRays binds the pipeline at VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR, emits a pre-barrier with AS_READ + SHADER_READ/WRITE dst access into RAY_TRACING_SHADER_BIT_KHR, then calls vkCmdTraceRaysKHR with the SBT's four region structs.
- createCommands picks the new bind point for RT pipelines so vkCmdBindDescriptorSets binds to the right point.

executeProgram's isRayTracing branch builds a RayTracingPipelineCreateDesc from the YAML, calls createPipelineRT then createShaderBindingTable, and keeps both on InvocationState for the dispatch.

raygen-roundtrip.test now expects DirectX/Metal/Clang to XFAIL; on a DXC + Vulkan combo with VK_KHR_ray_tracing_pipeline supported the test should PASS via this implementation. On the user's Linux + clang-dxc loop the test still XFAILs because clang-dxc doesn't yet lower [shader("raygeneration")] entry points to SPIR-V, so the Clang XFAIL token catches the compile failure. CI on a working DXC install will exercise the runtime path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
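The per-region stride and size rules stated in the createShaderBindingTable bullet can be checked with a small standalone sketch. The names `roundUpPow2`, `Region`, and `layoutRegion` are hypothetical, and the sample values used in the note below (32-byte handles, 32-byte handle alignment, 64-byte base alignment) are assumptions for illustration; a real backend would read them from VkPhysicalDeviceRayTracingPipelinePropertiesKHR.

```cpp
#include <cassert>
#include <cstdint>

// Round Value up to the next multiple of Alignment.
// The mask trick is only valid when Alignment is a nonzero power of two,
// which this sketch assumes and asserts.
inline uint32_t roundUpPow2(uint32_t Value, uint32_t Alignment) {
  assert(Alignment != 0 && (Alignment & (Alignment - 1)) == 0);
  return (Value + Alignment - 1) & ~(Alignment - 1);
}

// Hypothetical holder for one region's numbers.
struct Region {
  uint32_t Stride;
  uint32_t Size;
};

// stride = align(handleSize + max local root data, handleAlignment)
// size   = align(count * stride, baseAlignment), or 0 for an empty region
inline Region layoutRegion(uint32_t Count, uint32_t MaxLocalRootBytes,
                           uint32_t HandleSize, uint32_t HandleAlign,
                           uint32_t BaseAlign) {
  const uint32_t Stride =
      roundUpPow2(HandleSize + MaxLocalRootBytes, HandleAlign);
  const uint32_t Size =
      Count == 0 ? 0u : roundUpPow2(Count * Stride, BaseAlign);
  return Region{Stride, Size};
}
```

Under those assumed values, three miss records carrying at most 8 bytes of local root data give a stride of align(40, 32) = 64 and a region size of align(192, 64) = 192, while a single raygen record with no local data gives a stride of 32 and a region size of 64.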
1 parent 22639c9 commit a23da10

10 files changed

Lines changed: 755 additions & 40 deletions

include/API/Device.h

Lines changed: 24 additions & 0 deletions
@@ -20,6 +20,7 @@
 #include "API/Capabilities.h"
 #include "API/CommandBuffer.h"
 #include "API/RenderPass.h"
+#include "API/ShaderBindingTable.h"
 #include "API/Texture.h"

 #include "Support/Pipeline.h"
@@ -84,6 +85,21 @@ struct ShaderContainer {
   llvm::SmallVector<SpecializationConstant> SpecializationConstants;
 };

+struct RayTracingShader {
+  Stages Stage;
+  std::string EntryPoint;
+};
+
+struct RayTracingPipelineCreateDesc {
+  // All RT shaders are compiled into a single DXIL library; every entry in
+  // `Shaders` references this same blob via the backend's library-loading
+  // path.
+  const llvm::MemoryBuffer *Library = nullptr;
+  llvm::SmallVector<RayTracingShader> Shaders;
+  llvm::SmallVector<HitGroup> HitGroups;
+  RayTracingPipelineConfig Config;
+};
+
 struct TraditionalRasterPipelineCreateDesc {
   llvm::SmallVector<InputLayoutDesc> InputLayout;
   llvm::SmallVector<Format> RTFormats;
@@ -259,6 +275,14 @@ class Device {
       llvm::StringRef Name, const BindingsDesc &BindingsDesc,
       const MeshShaderRasterPipelineCreateDesc &Desc) = 0;

+  virtual llvm::Expected<std::unique_ptr<PipelineState>>
+  createPipelineRT(llvm::StringRef Name, const BindingsDesc &BindingsDesc,
+                   const RayTracingPipelineCreateDesc &Desc) = 0;
+
+  virtual llvm::Expected<std::unique_ptr<ShaderBindingTable>>
+  createShaderBindingTable(const PipelineState &PSO,
+                           const ShaderBindingTableDesc &Desc) = 0;
+
   virtual llvm::Expected<std::unique_ptr<Fence>>
   createFence(llvm::StringRef Name) = 0;

include/API/Encoder.h

Lines changed: 12 additions & 0 deletions
@@ -24,6 +24,7 @@ namespace offloadtest {
 class Buffer;
 class PipelineState;
 class AccelerationStructure;
+class ShaderBindingTable;
 struct BLASBuildRequest;
 struct TLASBuildRequest;

@@ -104,6 +105,17 @@ class ComputeEncoder : public CommandEncoder {
   /// Metal). A barrier covering AS-build writes is implicitly emitted before
   /// any subsequent command that reads from the freshly-built structures.
   virtual llvm::Error batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) = 0;
+
+  /// Trace rays from a RayTracing pipeline. \p PSO must have been created via
+  /// Device::createPipelineRT and \p SBT via Device::createShaderBindingTable
+  /// on that same PSO. \p Width, \p Height, \p Depth are the dispatch
+  /// dimensions passed through to the backend's DispatchRays equivalent
+  /// (D3D12 DispatchRays, Vulkan vkCmdTraceRaysKHR, Metal compute dispatch
+  /// after metal_irconverter lowering).
+  virtual llvm::Error dispatchRays(const PipelineState &PSO,
+                                   const ShaderBindingTable &SBT,
+                                   uint32_t Width, uint32_t Height,
+                                   uint32_t Depth) = 0;
 };

 struct Viewport {

include/API/ShaderBindingTable.h

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+//===- ShaderBindingTable.h - Offload RT Shader Binding Table -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef OFFLOADTEST_API_SHADERBINDINGTABLE_H
+#define OFFLOADTEST_API_SHADERBINDINGTABLE_H
+
+#include "API/API.h"
+
+#include <cstdint>
+
+namespace offloadtest {
+
+struct ShaderBindingTableDesc;
+
+/// Runtime shader binding table built from a RayTracing PipelineState plus a
+/// ShaderBindingTableDesc. Concrete subclasses (one per backend) hold the
+/// device-side records and any address ranges needed by the backend's
+/// DispatchRays call.
+class ShaderBindingTable {
+  GPUAPI API;
+
+public:
+  virtual ~ShaderBindingTable();
+  ShaderBindingTable(const ShaderBindingTable &) = delete;
+  ShaderBindingTable &operator=(const ShaderBindingTable &) = delete;
+
+  GPUAPI getAPI() const { return API; }
+
+protected:
+  explicit ShaderBindingTable(GPUAPI API) : API(API) {}
+};
+
+/// Per-region SBT layout numbers.
+///
+/// Every backend lays an SBT out as four concatenated regions (raygen, miss,
+/// hit-group, callable). Within a region every record is
+/// `[shader-identifier][LocalRootData][padding-to-stride]`, where stride is
+/// `align(identifierSize + max-LocalRootData-in-region, RecordAlign)` and
+/// the region itself is aligned to `BaseAlign`. The numbers match the
+/// alignment rules of both D3D12
+/// (`D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES` / `…_RECORD_BYTE_ALIGNMENT` /
+/// `…_TABLE_BYTE_ALIGNMENT`) and Vulkan
+/// (`shaderGroupHandleSize` / `shaderGroupHandleAlignment` /
+/// `shaderGroupBaseAlignment`); backend-specific constants are passed in.
+struct SBTRegionLayout {
+  uint32_t Stride = 0;
+  uint32_t Size = 0;
+  uint32_t Offset = 0; // byte offset from the start of the SBT buffer
+};
+
+struct SBTLayout {
+  SBTRegionLayout RayGen;
+  SBTRegionLayout Miss;
+  SBTRegionLayout HitGroup;
+  SBTRegionLayout Callable;
+  uint32_t TotalSize = 0;
+};
+
+/// Compute the per-region layout for an SBT description.
+///
+/// \p IdentifierSize is the size of one shader identifier (32 bytes on
+/// D3D12; `shaderGroupHandleSize` on Vulkan).
+/// \p RecordAlign is the per-record alignment (D3D12: 32 bytes; Vulkan:
+/// `shaderGroupHandleAlignment`).
+/// \p BaseAlign is the per-region alignment (D3D12: 64 bytes; Vulkan:
+/// `shaderGroupBaseAlignment`).
+SBTLayout computeSBTLayout(uint32_t IdentifierSize, uint32_t RecordAlign,
+                           uint32_t BaseAlign,
+                           const ShaderBindingTableDesc &Desc);
+
+} // namespace offloadtest
+
+#endif // OFFLOADTEST_API_SHADERBINDINGTABLE_H
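
The commit message says concrete backend SBT classes derive from this base with LLVM-style classof / cast<>. The following is a minimal sketch of that pattern without the LLVM headers. The `GPUAPI` enumerators, the `VKShaderBindingTable` and `DXShaderBindingTable` classes, their members, and the `dynCast` helper are all hypothetical stand-ins, not code from this commit; in the real tree `llvm::dyn_cast` / `llvm::cast` would play the role of `dynCast`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the GPUAPI enum declared in API/API.h.
enum class GPUAPI { DirectX, Vulkan, Metal };

// Same shape as the base class in this header (copy control omitted).
class ShaderBindingTable {
  GPUAPI API;

public:
  virtual ~ShaderBindingTable() = default;
  GPUAPI getAPI() const { return API; }

protected:
  explicit ShaderBindingTable(GPUAPI API) : API(API) {}
};

// Hypothetical Vulkan subclass carrying some backend-specific state.
class VKShaderBindingTable : public ShaderBindingTable {
public:
  uint64_t BufferAddress = 0; // placeholder field for illustration

  VKShaderBindingTable() : ShaderBindingTable(GPUAPI::Vulkan) {}

  // LLVM-style RTTI hook: is this base object really a Vulkan SBT?
  static bool classof(const ShaderBindingTable *S) {
    return S->getAPI() == GPUAPI::Vulkan;
  }
};

// Hypothetical D3D12 subclass, present only to show a kind mismatch.
class DXShaderBindingTable : public ShaderBindingTable {
public:
  DXShaderBindingTable() : ShaderBindingTable(GPUAPI::DirectX) {}

  static bool classof(const ShaderBindingTable *S) {
    return S->getAPI() == GPUAPI::DirectX;
  }
};

// Minimal stand-in for llvm::dyn_cast: nullptr when the kind does not match.
template <typename To, typename From> const To *dynCast(const From *Val) {
  return To::classof(Val) ? static_cast<const To *>(Val) : nullptr;
}
```

A backend's dispatchRays could use such a cast to recover its own SBT type from the `const ShaderBindingTable &` it receives, and reject tables created by another backend.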

include/Support/Pipeline.h

Lines changed: 4 additions & 4 deletions
@@ -592,7 +592,7 @@ struct SBTEntry {
   llvm::SmallVector<uint8_t> LocalRootData;
 };

-struct ShaderBindingTable {
+struct ShaderBindingTableDesc {
   SBTEntry RayGen;
   llvm::SmallVector<SBTEntry> Miss;
   llvm::SmallVector<SBTEntry> HitGroup;
@@ -614,7 +614,7 @@ struct Pipeline {
   AccelerationStructureDescs AccelStructs;
   std::optional<RayTracingPipelineConfig> RTConfig;
   llvm::SmallVector<HitGroup> HitGroups;
-  std::optional<ShaderBindingTable> SBT;
+  std::optional<ShaderBindingTableDesc> SBT;

   uint32_t getVertexCount() const {
     if (DispatchParameters.VertexCount)
@@ -825,8 +825,8 @@ template <> struct MappingTraits<offloadtest::SBTEntry> {
   static void mapping(IO &I, offloadtest::SBTEntry &E);
 };

-template <> struct MappingTraits<offloadtest::ShaderBindingTable> {
-  static void mapping(IO &I, offloadtest::ShaderBindingTable &S);
+template <> struct MappingTraits<offloadtest::ShaderBindingTableDesc> {
+  static void mapping(IO &I, offloadtest::ShaderBindingTableDesc &S);
 };

 template <> struct ScalarEnumerationTraits<offloadtest::Rule> {

lib/API/DX/Device.cpp

Lines changed: 20 additions & 0 deletions
@@ -795,6 +795,12 @@ class DXComputeEncoder : public offloadtest::ComputeEncoder {
   // ID3D12Device5 entry point and helper allocators.
   llvm::Error batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) override;

+  llvm::Error dispatchRays(const PipelineState &, const ShaderBindingTable &,
+                           uint32_t, uint32_t, uint32_t) override {
+    return llvm::createStringError(
+        "RayTracing dispatchRays not yet supported on DirectX");
+  }
+
   void endEncodingImpl() override { popDebugGroup(); }
 };

@@ -1478,6 +1484,20 @@ class DXDevice : public offloadtest::Device {
     return std::make_unique<DXPipelineState>(Name, RootSig, PSO, std::nullopt);
   }

+  llvm::Expected<std::unique_ptr<PipelineState>>
+  createPipelineRT(llvm::StringRef, const BindingsDesc &,
+                   const RayTracingPipelineCreateDesc &) override {
+    return llvm::createStringError(
+        "RayTracing pipeline state not yet supported on DirectX");
+  }
+
+  llvm::Expected<std::unique_ptr<ShaderBindingTable>>
+  createShaderBindingTable(const PipelineState &,
+                           const ShaderBindingTableDesc &) override {
+    return llvm::createStringError(
+        "RayTracing shader binding table not yet supported on DirectX");
+  }
+
   llvm::Expected<std::unique_ptr<offloadtest::Fence>>
   createFence(llvm::StringRef Name) override {
     return DXFence::create(Device.Get(), Name);

lib/API/Device.cpp

Lines changed: 51 additions & 0 deletions
@@ -12,6 +12,7 @@
 #include "API/Device.h"
 #include "API/Encoder.h"
 #include "API/FormatConversion.h"
+#include "API/ShaderBindingTable.h"

 #include "Config.h"

@@ -39,6 +40,56 @@ RenderPass::~RenderPass() {}

 AccelerationStructure::~AccelerationStructure() {}

+ShaderBindingTable::~ShaderBindingTable() {}
+
+static uint32_t alignUp(uint32_t Value, uint32_t Alignment) {
+  return (Value + Alignment - 1) & ~(Alignment - 1);
+}
+
+SBTLayout offloadtest::computeSBTLayout(uint32_t IdentifierSize,
+                                        uint32_t RecordAlign,
+                                        uint32_t BaseAlign,
+                                        const ShaderBindingTableDesc &Desc) {
+  auto StrideFor = [&](llvm::ArrayRef<SBTEntry> Entries) {
+    size_t MaxLocal = 0;
+    for (const auto &E : Entries)
+      MaxLocal = std::max<size_t>(MaxLocal, E.LocalRootData.size());
+    return alignUp(IdentifierSize + static_cast<uint32_t>(MaxLocal),
+                   RecordAlign);
+  };
+  auto RegionSize = [&](uint32_t Count, uint32_t Stride) {
+    return Count == 0 ? 0u : alignUp(Count * Stride, BaseAlign);
+  };
+
+  // Vulkan dispatches exactly one raygen per vkCmdTraceRaysKHR and D3D12's
+  // RayGenerationShaderRecord field is a single record; the descriptor only
+  // carries one raygen entry.
+  const llvm::ArrayRef<SBTEntry> RGEntries(&Desc.RayGen, 1);
+
+  SBTLayout L;
+  L.RayGen.Stride = StrideFor(RGEntries);
+  L.RayGen.Size = RegionSize(1, L.RayGen.Stride);
+  L.RayGen.Offset = 0;
+
+  L.Miss.Stride = StrideFor(Desc.Miss);
+  L.Miss.Size =
+      RegionSize(static_cast<uint32_t>(Desc.Miss.size()), L.Miss.Stride);
+  L.Miss.Offset = L.RayGen.Offset + L.RayGen.Size;
+
+  L.HitGroup.Stride = StrideFor(Desc.HitGroup);
+  L.HitGroup.Size = RegionSize(static_cast<uint32_t>(Desc.HitGroup.size()),
+                               L.HitGroup.Stride);
+  L.HitGroup.Offset = L.Miss.Offset + L.Miss.Size;
+
+  L.Callable.Stride = StrideFor(Desc.Callable);
+  L.Callable.Size = RegionSize(static_cast<uint32_t>(Desc.Callable.size()),
+                               L.Callable.Stride);
+  L.Callable.Offset = L.HitGroup.Offset + L.HitGroup.Size;
+
+  L.TotalSize = L.Callable.Offset + L.Callable.Size;
+  return L;
+}
+
 Device::~Device() {}

 llvm::Expected<llvm::SmallVector<std::unique_ptr<Device>>>
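
To see how the four regions chain together, the following standalone sketch reproduces the offset arithmetic of computeSBTLayout with plain standard-library types. It is a simplification under stated assumptions: each record is reduced to its LocalRootData byte count, and the names `MiniRegion`, `MiniLayout`, `placeRegion`, and `placeAll` are hypothetical. The constants in the note below are the D3D12 values quoted in the ShaderBindingTable.h comments (32-byte identifier, 32-byte record alignment, 64-byte table alignment).

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

struct MiniRegion {
  uint32_t Stride = 0, Size = 0, Offset = 0;
};

struct MiniLayout {
  MiniRegion RayGen, Miss, HitGroup, Callable;
  uint32_t TotalSize = 0;
};

// Power-of-two round-up, same mask trick as alignUp in the diff.
inline uint32_t roundUp(uint32_t V, uint32_t A) {
  return (V + A - 1) & ~(A - 1);
}

// LocalSizes holds one LocalRootData byte count per record in the region.
inline MiniRegion placeRegion(const std::vector<uint32_t> &LocalSizes,
                              uint32_t IdSize, uint32_t RecordAlign,
                              uint32_t BaseAlign, uint32_t Offset) {
  uint32_t MaxLocal = 0;
  for (uint32_t S : LocalSizes)
    MaxLocal = std::max(MaxLocal, S);
  MiniRegion R;
  R.Stride = roundUp(IdSize + MaxLocal, RecordAlign);
  const uint32_t Count = static_cast<uint32_t>(LocalSizes.size());
  R.Size = Count == 0 ? 0u : roundUp(Count * R.Stride, BaseAlign);
  R.Offset = Offset;
  return R;
}

// Regions are concatenated: each one starts where the previous one ends.
inline MiniLayout placeAll(uint32_t RayGenLocal,
                           const std::vector<uint32_t> &Miss,
                           const std::vector<uint32_t> &Hit,
                           const std::vector<uint32_t> &Callable,
                           uint32_t IdSize, uint32_t RecordAlign,
                           uint32_t BaseAlign) {
  MiniLayout L;
  L.RayGen = placeRegion({RayGenLocal}, IdSize, RecordAlign, BaseAlign, 0);
  L.Miss = placeRegion(Miss, IdSize, RecordAlign, BaseAlign,
                       L.RayGen.Offset + L.RayGen.Size);
  L.HitGroup = placeRegion(Hit, IdSize, RecordAlign, BaseAlign,
                           L.Miss.Offset + L.Miss.Size);
  L.Callable = placeRegion(Callable, IdSize, RecordAlign, BaseAlign,
                           L.HitGroup.Offset + L.HitGroup.Size);
  L.TotalSize = L.Callable.Offset + L.Callable.Size;
  return L;
}
```

With one raygen record, two miss records, three hit-group records (one carrying 16 bytes of local data), and no callables, the regions land at offsets 0, 64, 128, and 320, for a total of 320 bytes. Because every non-empty region size is a multiple of the base alignment, every region offset is too, so the regions stay aligned as long as the buffer start itself is.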

lib/API/MTL/MTLDevice.cpp

Lines changed: 20 additions & 0 deletions
@@ -598,6 +598,12 @@ class MTLComputeEncoder : public offloadtest::ComputeEncoder {
   // MTL::Device handle (used to allocate scratch and instance buffers).
   llvm::Error batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) override;

+  llvm::Error dispatchRays(const PipelineState &, const ShaderBindingTable &,
+                           uint32_t, uint32_t, uint32_t) override {
+    return llvm::createStringError(
+        "RayTracing dispatchRays not yet supported on Metal");
+  }
+
   /// Lazily transition into an AccelerationStructureCommandEncoder; mirrors
   /// the existing compute↔blit lazy switch.
   llvm::Error ensureASEncoder() {
@@ -1678,6 +1684,20 @@ class MTLDevice : public offloadtest::Device {

   Queue &getGraphicsQueue() override { return GraphicsQueue; }

+  llvm::Expected<std::unique_ptr<PipelineState>>
+  createPipelineRT(llvm::StringRef, const BindingsDesc &,
+                   const RayTracingPipelineCreateDesc &) override {
+    return llvm::createStringError(
+        "RayTracing pipeline state not yet supported on Metal");
+  }
+
+  llvm::Expected<std::unique_ptr<ShaderBindingTable>>
+  createShaderBindingTable(const PipelineState &,
+                           const ShaderBindingTableDesc &) override {
+    return llvm::createStringError(
+        "RayTracing shader binding table not yet supported on Metal");
+  }
+
   llvm::Expected<std::unique_ptr<offloadtest::Fence>>
   createFence(llvm::StringRef Name) override {
     return MTLFence::create(Device, Name);
