Skip to content

Commit 1ddd4d9

Browse files
MarijnS95claudeEmilioLaiso
authored
Add RayTracing pipeline kind, shader stages, and YAML schema (llvm#1270)
Depends on llvm#1245 ## Summary Foundational PR in the PSO-based raytracing bring-up series tracked in llvm#1268. Stacks on top of llvm#1245 (which depends on llvm#1244, which depends on llvm#1232) — only the top commit on this branch is new; the rest are the inline-RT bring-up already in review. Lays out the framework-side surface needed by the upcoming backend PRs: - `ShaderPipelineKind::RayTracing` plus six new `Stages` — `RayGeneration`, `Miss`, `ClosestHit`, `AnyHit`, `Intersection`, `Callable` — with `isRayTracingStage` / `Pipeline::isRayTracing()` helpers. - YAML schema for an RT pipeline: `HitGroup` (Triangles | Procedural, ClosestHit + optional AnyHit / Intersection), `RayTracingPipelineConfig` (MaxTraceRecursionDepth, MaxPayloadSizeInBytes, MaxAttributeSizeInBytes, optional PipelineFlags), and `ShaderBindingTable` (raygen / miss / hit-group / callable records, each with optional reserved LocalRootData bytes). - `validatePipelineKind` allows duplicate RT stages (a pipeline can have several miss / hit-group shaders, which the existing duplicate check would have rejected), requires at least one RayGeneration, and rejects mixing with Compute/Vertex/Mesh. The reverse check rejects HitGroups / RTConfig / SBT on any non-RT pipeline. `validateDispatchParameters` reinterprets `DispatchGroupCount` as `{Width, Height, Depth}` for the upcoming DispatchRays and forbids VertexCount on RT. - Existing `Stages` switches across the backends grow the six RT cases — Vulkan maps each one to its `VK_SHADER_STAGE_*_KHR` bit ready for PR 2; Metal unreachables on RT (`metal_irconverter` takes a different route); raster pipeline `setShader` (Traditional + MeshShader variants) adds them to the existing unreachable group. - Each backend's `executeProgram` gets a terminal `else if (P.isRayTracing())` that returns a "not yet supported on <backend>" error so PR2/3/4 just have to replace it. - `%dxc_target_lib` lit substitution (same compiler binary, separate name for `-T lib_6_x` library targets); `raytracing-pipeline` available-feature gated on DX `RaytracingTier >= 1.0` and the Vulkan `VK_KHR_ray_tracing_pipeline` extension being reported by the device. - Foundational `test/Feature/RT/raygen-roundtrip.test` exercising the full schema (raygen+miss+CH, BLAS/TLAS, HitGroups, RTConfig, SBT). Gated on `raytracing-pipeline` and `XFAIL: *` until each backend bring-up lands. ## Test plan Local on an NVIDIA RTX 3060: - [x] Linux Vulkan (native `offloader`) - [ ] Linux D3D12 (Wine + vkd3d-proton + cross-compiled `offloader.exe`) - [ ] Windows Vulkan (native `offloader.exe`) - [ ] Windows D3D12 (native `offloader.exe`) CI (RT-capable runners): - [ ] windows-nvidia D3D12 (`RaytracingTier 1.2`) - [ ] windows-intel VK (`VK_KHR_ray_tracing_pipeline`) - [x] macOS Metal (`supportsRaytracing`) --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: EmilioLaiso <emilio@traverseresearch.nl>
1 parent 00dd753 commit 1ddd4d9

8 files changed

Lines changed: 393 additions & 17 deletions

File tree

include/API/Device.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,12 @@ struct TraditionalRasterPipelineCreateDesc {
121121
case Stages::Compute:
122122
case Stages::Amplification:
123123
case Stages::Mesh:
124+
case Stages::RayGeneration:
125+
case Stages::Miss:
126+
case Stages::ClosestHit:
127+
case Stages::AnyHit:
128+
case Stages::Intersection:
129+
case Stages::Callable:
124130
llvm_unreachable("Not a traditional raster pipeline stage.");
125131
}
126132
}
@@ -151,6 +157,12 @@ struct MeshShaderRasterPipelineCreateDesc {
151157
case Stages::Domain:
152158
case Stages::Geometry:
153159
case Stages::Compute:
160+
case Stages::RayGeneration:
161+
case Stages::Miss:
162+
case Stages::ClosestHit:
163+
case Stages::AnyHit:
164+
case Stages::Intersection:
165+
case Stages::Callable:
154166
llvm_unreachable("Not a mesh raster pipeline stage.");
155167
}
156168
}

include/Support/Pipeline.h

Lines changed: 113 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,53 @@ enum class Stages {
4141

4242
// Mesh Shader Raster
4343
Amplification,
44-
Mesh
44+
Mesh,
45+
46+
// Ray Tracing
47+
RayGeneration,
48+
Miss,
49+
ClosestHit,
50+
AnyHit,
51+
Intersection,
52+
Callable
4553
};
4654
inline constexpr std::array AllStages = {
47-
Stages::Compute, Stages::Vertex, Stages::Hull, Stages::Domain,
48-
Stages::Geometry, Stages::Pixel, Stages::Amplification, Stages::Mesh,
55+
Stages::Compute, Stages::Vertex, Stages::Hull,
56+
Stages::Domain, Stages::Geometry, Stages::Pixel,
57+
Stages::Amplification, Stages::Mesh, Stages::RayGeneration,
58+
Stages::Miss, Stages::ClosestHit, Stages::AnyHit,
59+
Stages::Intersection, Stages::Callable,
4960
};
5061
inline constexpr size_t NumStages = AllStages.size();
5162

52-
enum class ShaderPipelineKind { Compute, TraditionalRaster, MeshShaderRaster };
63+
inline constexpr bool isRayTracingStage(Stages S) {
64+
switch (S) {
65+
case Stages::RayGeneration:
66+
case Stages::Miss:
67+
case Stages::ClosestHit:
68+
case Stages::AnyHit:
69+
case Stages::Intersection:
70+
case Stages::Callable:
71+
return true;
72+
case Stages::Compute:
73+
case Stages::Vertex:
74+
case Stages::Hull:
75+
case Stages::Domain:
76+
case Stages::Geometry:
77+
case Stages::Pixel:
78+
case Stages::Amplification:
79+
case Stages::Mesh:
80+
return false;
81+
}
82+
llvm_unreachable("All stages handled");
83+
}
84+
85+
enum class ShaderPipelineKind {
86+
Compute,
87+
TraditionalRaster,
88+
MeshShaderRaster,
89+
RayTracing
90+
};
5391

5492
enum class Rule { BufferExact, BufferFloatULP, BufferFloatEpsilon };
5593

@@ -528,6 +566,40 @@ struct AccelerationStructureDescs {
528566
llvm::SmallVector<TLASDesc, 1> TLAS;
529567
};
530568

569+
enum class HitGroupType { Triangles, Procedural };
570+
571+
struct HitGroup {
572+
std::string Name;
573+
HitGroupType Type = HitGroupType::Triangles;
574+
std::string ClosestHit;
575+
std::optional<std::string> AnyHit;
576+
std::optional<std::string> Intersection;
577+
};
578+
579+
struct RayTracingPipelineConfig {
580+
uint32_t MaxTraceRecursionDepth = 1;
581+
uint32_t MaxPayloadSizeInBytes = 0;
582+
uint32_t MaxAttributeSizeInBytes = 8;
583+
std::optional<uint32_t> PipelineFlags;
584+
};
585+
586+
struct SBTEntry {
587+
// For RayGen / Miss / Callable entries: the shader's Entry name.
588+
// For HitGroup entries: the HitGroup's Name.
589+
std::string ShaderName;
590+
// Optional per-record local-root data, laid out as the local root signature
591+
// describes. Not used during PR1 bring-up; reserved here so the schema is
592+
// stable when local root signatures land.
593+
llvm::SmallVector<uint8_t> LocalRootData;
594+
};
595+
596+
struct ShaderBindingTable {
597+
SBTEntry RayGen;
598+
llvm::SmallVector<SBTEntry> Miss;
599+
llvm::SmallVector<SBTEntry> HitGroup;
600+
llvm::SmallVector<SBTEntry> Callable;
601+
};
602+
531603
struct Pipeline {
532604
ShaderPipelineKind Kind;
533605
llvm::SmallVector<Shader> Shaders;
@@ -541,6 +613,9 @@ struct Pipeline {
541613
llvm::SmallVector<DescriptorSet> Sets;
542614
DispatchParametersSet DispatchParameters;
543615
AccelerationStructureDescs AccelStructs;
616+
std::optional<RayTracingPipelineConfig> RTConfig;
617+
llvm::SmallVector<HitGroup> HitGroups;
618+
std::optional<ShaderBindingTable> SBT;
544619

545620
uint32_t getVertexCount() const {
546621
if (DispatchParameters.VertexCount)
@@ -608,6 +683,7 @@ struct Pipeline {
608683
bool isRaster() const {
609684
return isTraditionalRaster() || isMeshShaderRaster();
610685
}
686+
bool isRayTracing() const { return Kind == ShaderPipelineKind::RayTracing; }
611687
};
612688
} // namespace offloadtest
613689

@@ -628,6 +704,8 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::AABBGeometry)
628704
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::BLASDesc)
629705
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::InstanceDesc)
630706
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::TLASDesc)
707+
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::HitGroup)
708+
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::SBTEntry)
631709

632710
namespace llvm {
633711
namespace yaml {
@@ -736,6 +814,22 @@ template <> struct MappingTraits<offloadtest::AccelerationStructureDescs> {
736814
static void mapping(IO &I, offloadtest::AccelerationStructureDescs &D);
737815
};
738816

817+
template <> struct MappingTraits<offloadtest::HitGroup> {
818+
static void mapping(IO &I, offloadtest::HitGroup &G);
819+
};
820+
821+
template <> struct MappingTraits<offloadtest::RayTracingPipelineConfig> {
822+
static void mapping(IO &I, offloadtest::RayTracingPipelineConfig &C);
823+
};
824+
825+
template <> struct MappingTraits<offloadtest::SBTEntry> {
826+
static void mapping(IO &I, offloadtest::SBTEntry &E);
827+
};
828+
829+
template <> struct MappingTraits<offloadtest::ShaderBindingTable> {
830+
static void mapping(IO &I, offloadtest::ShaderBindingTable &S);
831+
};
832+
739833
template <> struct ScalarEnumerationTraits<offloadtest::Rule> {
740834
static void enumeration(IO &I, offloadtest::Rule &V) {
741835
#define ENUM_CASE(Val) I.enumCase(V, #Val, offloadtest::Rule::Val)
@@ -887,6 +981,21 @@ template <> struct ScalarEnumerationTraits<offloadtest::Stages> {
887981
ENUM_CASE(Pixel);
888982
ENUM_CASE(Amplification);
889983
ENUM_CASE(Mesh);
984+
ENUM_CASE(RayGeneration);
985+
ENUM_CASE(Miss);
986+
ENUM_CASE(ClosestHit);
987+
ENUM_CASE(AnyHit);
988+
ENUM_CASE(Intersection);
989+
ENUM_CASE(Callable);
990+
#undef ENUM_CASE
991+
}
992+
};
993+
994+
template <> struct ScalarEnumerationTraits<offloadtest::HitGroupType> {
995+
static void enumeration(IO &I, offloadtest::HitGroupType &V) {
996+
#define ENUM_CASE(Val) I.enumCase(V, #Val, offloadtest::HitGroupType::Val)
997+
ENUM_CASE(Triangles);
998+
ENUM_CASE(Procedural);
890999
#undef ENUM_CASE
8911000
}
8921001
};

lib/API/DX/Device.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3131,6 +3131,9 @@ class DXDevice : public offloadtest::Device {
31313131
if (auto Err = createGraphicsCommands(P, State))
31323132
return Err;
31333133
llvm::outs() << "Graphics command list created complete.\n";
3134+
} else if (P.isRayTracing()) {
3135+
return llvm::createStringError(
3136+
"RayTracing pipeline not yet supported on DirectX");
31343137
} else {
31353138
return llvm::createStringError("Pipeline was neither Compute nor Raster");
31363139
}

lib/API/MTL/MTLDevice.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,13 @@ static IRShaderStage getShaderStage(Stages Stage) {
148148
return IRShaderStageAmplification;
149149
case Stages::Mesh:
150150
return IRShaderStageMesh;
151+
case Stages::RayGeneration:
152+
case Stages::Miss:
153+
case Stages::ClosestHit:
154+
case Stages::AnyHit:
155+
case Stages::Intersection:
156+
case Stages::Callable:
157+
llvm_unreachable("RayTracing shaders take a different path on Metal.");
151158
}
152159
llvm_unreachable("All cases handled");
153160
}
@@ -2408,6 +2415,9 @@ class MTLDevice : public offloadtest::Device {
24082415

24092416
if (auto Err = createGraphicsCommands(P, IS))
24102417
return Err;
2418+
} else if (P.isRayTracing()) {
2419+
return llvm::createStringError(
2420+
"RayTracing pipeline not yet supported on Metal");
24112421
}
24122422

24132423
auto SubmitResult = GraphicsQueue.submit(std::move(IS.CB));

lib/API/VK/Device.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,18 @@ static VkShaderStageFlagBits getShaderStageFlag(Stages Stage) {
223223
return VK_SHADER_STAGE_TASK_BIT_EXT;
224224
case Stages::Mesh:
225225
return VK_SHADER_STAGE_MESH_BIT_EXT;
226+
case Stages::RayGeneration:
227+
return VK_SHADER_STAGE_RAYGEN_BIT_KHR;
228+
case Stages::Miss:
229+
return VK_SHADER_STAGE_MISS_BIT_KHR;
230+
case Stages::ClosestHit:
231+
return VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
232+
case Stages::AnyHit:
233+
return VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
234+
case Stages::Intersection:
235+
return VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
236+
case Stages::Callable:
237+
return VK_SHADER_STAGE_CALLABLE_BIT_KHR;
226238
}
227239
llvm_unreachable("All cases handled");
228240
}
@@ -4450,6 +4462,9 @@ class VulkanDevice : public offloadtest::Device {
44504462
if (auto Err = createFrameBuffer(State))
44514463
return Err;
44524464
llvm::outs() << "Frame buffer created.\n";
4465+
} else if (P.isRayTracing()) {
4466+
return llvm::createStringError(
4467+
"RayTracing pipeline not yet supported on Vulkan");
44534468
} else {
44544469
return llvm::createStringError(
44554470
"Pipeline was neither Compute nor Traditional Raster");

0 commit comments

Comments
 (0)