Skip to content

Commit e7e4cd3

Browse files
MarijnS95claude
andcommitted
Add RayTracing pipeline kind, shader stages, and YAML schema
Foundational bring-up for PSO-based raytracing tracked in #1268. Lays out the framework-side surface (stage enums, pipeline kind, YAML schema, lit infrastructure) so subsequent per-backend bring-up PRs (VK → DX12 → Metal) only have to fill in pipeline-state-object creation, SBT construction, and DispatchRays. No backend can run an RT pipeline yet — each one's executeProgram gains a terminal `else if (P.isRayTracing())` that returns a "not yet supported" error. Pipeline.h gets six new Stages (RayGeneration, Miss, ClosestHit, AnyHit, Intersection, Callable), `ShaderPipelineKind::RayTracing`, an `isRayTracingStage` predicate, and `Pipeline::isRayTracing()`. The declarative YAML schema for an RT pipeline lives alongside the existing AccelerationStructureDescs: a `HitGroup` (Triangles | Procedural, with ClosestHit + optional AnyHit / Intersection entries), a `RayTracingPipelineConfig` block (MaxTraceRecursionDepth, MaxPayloadSizeInBytes, MaxAttributeSizeInBytes, optional PipelineFlags), and a `ShaderBindingTable` block with raygen / miss / hit-group / callable record arrays. SBTEntry carries an optional `LocalRootData` byte array reserved for the upcoming local-root-signature work. validatePipelineKind grows an RT branch: it allows multiple shaders of the same RT stage (a pipeline can have several misses or hit groups — the existing duplicate check would have rejected them), requires at least one RayGeneration, and rejects mixing RT with Compute/Vertex/Mesh. The reverse check rejects HitGroups / RTConfig / SBT on any non-RT pipeline. validateDispatchParameters reinterprets DispatchGroupCount as {Width, Height, Depth} for the eventual DispatchRays and forbids VertexCount on RT. Existing Stages switches grow the six new cases: * VK: getShaderStageFlag maps each RT stage to its VK_SHADER_STAGE_*_KHR bit so PR 2 can build VkPipelineShaderStageCreateInfos for the RT pipeline. * Metal: getShaderStage unreachables on RT (the metal-irconverter RT path takes a different route from the IRShaderStage one). * TraditionalRasterPipelineCreateDesc::setShader adds the RT stages to its existing "not a raster stage" unreachable group. test/lit.cfg.py adds a `%dxc_target_lib` substitution (same compiler, distinct name to signal `-T lib_6_x` library targets at a glance) and a `raytracing-pipeline` available-feature. On DX it tracks RaytracingTier >= 1.0; on Vulkan it aliases off the VK_KHR_ray_tracing_pipeline extension already reported by the device. The extension isn't enabled on the VkDevice yet — that lands in PR 2 — but the lit-level capability detection is independent of what the backend currently consumes, so a developer on a VK box can already see the foundational test routed through the RT path. The foundational test `Feature/RT/raygen-roundtrip.test` exercises the full RT YAML schema in one shape: raygen + miss + closest-hit shaders, a BLAS/TLAS pair, a HitGroups list, RayTracingPipelineConfig, and a ShaderBindingTable. `# REQUIRES: raytracing-pipeline` and `# XFAIL: *` keep it expectedly failing until the per-backend PRs drop entries from the XFAIL list as each one starts dispatching real rays. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent c822105 commit e7e4cd3

8 files changed

Lines changed: 387 additions & 17 deletions

File tree

include/API/Device.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ struct TraditionalRasterPipelineCreateDesc {
120120
case Stages::Compute:
121121
case Stages::Amplification:
122122
case Stages::Mesh:
123+
case Stages::RayGeneration:
124+
case Stages::Miss:
125+
case Stages::ClosestHit:
126+
case Stages::AnyHit:
127+
case Stages::Intersection:
128+
case Stages::Callable:
123129
llvm_unreachable("Not a traditional raster pipeline stage.");
124130
}
125131
}

include/Support/Pipeline.h

Lines changed: 113 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,53 @@ enum class Stages {
4141

4242
// Mesh Shader Raster
4343
Amplification,
44-
Mesh
44+
Mesh,
45+
46+
// Ray Tracing
47+
RayGeneration,
48+
Miss,
49+
ClosestHit,
50+
AnyHit,
51+
Intersection,
52+
Callable
4553
};
4654
inline constexpr std::array AllStages = {
47-
Stages::Compute, Stages::Vertex, Stages::Hull, Stages::Domain,
48-
Stages::Geometry, Stages::Pixel, Stages::Amplification, Stages::Mesh,
55+
Stages::Compute, Stages::Vertex, Stages::Hull,
56+
Stages::Domain, Stages::Geometry, Stages::Pixel,
57+
Stages::Amplification, Stages::Mesh, Stages::RayGeneration,
58+
Stages::Miss, Stages::ClosestHit, Stages::AnyHit,
59+
Stages::Intersection, Stages::Callable,
4960
};
5061
inline constexpr size_t NumStages = AllStages.size();
5162

52-
enum class ShaderPipelineKind { Compute, TraditionalRaster, MeshShaderRaster };
63+
inline constexpr bool isRayTracingStage(Stages S) {
64+
switch (S) {
65+
case Stages::RayGeneration:
66+
case Stages::Miss:
67+
case Stages::ClosestHit:
68+
case Stages::AnyHit:
69+
case Stages::Intersection:
70+
case Stages::Callable:
71+
return true;
72+
case Stages::Compute:
73+
case Stages::Vertex:
74+
case Stages::Hull:
75+
case Stages::Domain:
76+
case Stages::Geometry:
77+
case Stages::Pixel:
78+
case Stages::Amplification:
79+
case Stages::Mesh:
80+
return false;
81+
}
82+
llvm_unreachable("All stages handled");
83+
}
84+
85+
enum class ShaderPipelineKind {
86+
Compute,
87+
TraditionalRaster,
88+
MeshShaderRaster,
89+
RayTracing
90+
};
5391

5492
enum class Rule { BufferExact, BufferFloatULP, BufferFloatEpsilon };
5593

@@ -527,6 +565,40 @@ struct AccelerationStructureDescs {
527565
llvm::SmallVector<TLASDesc, 1> TLAS;
528566
};
529567

568+
enum class HitGroupType { Triangles, Procedural };
569+
570+
struct HitGroup {
571+
std::string Name;
572+
HitGroupType Type = HitGroupType::Triangles;
573+
std::string ClosestHit;
574+
std::optional<std::string> AnyHit;
575+
std::optional<std::string> Intersection;
576+
};
577+
578+
struct RayTracingPipelineConfig {
579+
uint32_t MaxTraceRecursionDepth = 1;
580+
uint32_t MaxPayloadSizeInBytes = 0;
581+
uint32_t MaxAttributeSizeInBytes = 8;
582+
std::optional<uint32_t> PipelineFlags;
583+
};
584+
585+
struct SBTEntry {
586+
// For RayGen / Miss / Callable entries: the shader's Entry name.
587+
// For HitGroup entries: the HitGroup's Name.
588+
std::string ShaderName;
589+
// Optional per-record local-root data, laid out as the local root signature
590+
// describes. Not used during PR1 bring-up; reserved here so the schema is
591+
// stable when local root signatures land.
592+
llvm::SmallVector<uint8_t> LocalRootData;
593+
};
594+
595+
struct ShaderBindingTable {
596+
SBTEntry RayGen;
597+
llvm::SmallVector<SBTEntry> Miss;
598+
llvm::SmallVector<SBTEntry> HitGroup;
599+
llvm::SmallVector<SBTEntry> Callable;
600+
};
601+
530602
struct Pipeline {
531603
ShaderPipelineKind Kind;
532604
llvm::SmallVector<Shader> Shaders;
@@ -540,6 +612,9 @@ struct Pipeline {
540612
llvm::SmallVector<DescriptorSet> Sets;
541613
DispatchParametersSet DispatchParameters;
542614
AccelerationStructureDescs AccelStructs;
615+
std::optional<RayTracingPipelineConfig> RTConfig;
616+
llvm::SmallVector<HitGroup> HitGroups;
617+
std::optional<ShaderBindingTable> SBT;
543618

544619
uint32_t getVertexCount() const {
545620
if (DispatchParameters.VertexCount)
@@ -607,6 +682,7 @@ struct Pipeline {
607682
bool isRaster() const {
608683
return isTraditionalRaster() || isMeshShaderRaster();
609684
}
685+
bool isRayTracing() const { return Kind == ShaderPipelineKind::RayTracing; }
610686
};
611687
} // namespace offloadtest
612688

@@ -627,6 +703,8 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::AABBGeometry)
627703
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::BLASDesc)
628704
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::InstanceDesc)
629705
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::TLASDesc)
706+
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::HitGroup)
707+
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::SBTEntry)
630708

631709
namespace llvm {
632710
namespace yaml {
@@ -735,6 +813,22 @@ template <> struct MappingTraits<offloadtest::AccelerationStructureDescs> {
735813
static void mapping(IO &I, offloadtest::AccelerationStructureDescs &D);
736814
};
737815

816+
template <> struct MappingTraits<offloadtest::HitGroup> {
817+
static void mapping(IO &I, offloadtest::HitGroup &G);
818+
};
819+
820+
template <> struct MappingTraits<offloadtest::RayTracingPipelineConfig> {
821+
static void mapping(IO &I, offloadtest::RayTracingPipelineConfig &C);
822+
};
823+
824+
template <> struct MappingTraits<offloadtest::SBTEntry> {
825+
static void mapping(IO &I, offloadtest::SBTEntry &E);
826+
};
827+
828+
template <> struct MappingTraits<offloadtest::ShaderBindingTable> {
829+
static void mapping(IO &I, offloadtest::ShaderBindingTable &S);
830+
};
831+
738832
template <> struct ScalarEnumerationTraits<offloadtest::Rule> {
739833
static void enumeration(IO &I, offloadtest::Rule &V) {
740834
#define ENUM_CASE(Val) I.enumCase(V, #Val, offloadtest::Rule::Val)
@@ -886,6 +980,21 @@ template <> struct ScalarEnumerationTraits<offloadtest::Stages> {
886980
ENUM_CASE(Pixel);
887981
ENUM_CASE(Amplification);
888982
ENUM_CASE(Mesh);
983+
ENUM_CASE(RayGeneration);
984+
ENUM_CASE(Miss);
985+
ENUM_CASE(ClosestHit);
986+
ENUM_CASE(AnyHit);
987+
ENUM_CASE(Intersection);
988+
ENUM_CASE(Callable);
989+
#undef ENUM_CASE
990+
}
991+
};
992+
993+
template <> struct ScalarEnumerationTraits<offloadtest::HitGroupType> {
994+
static void enumeration(IO &I, offloadtest::HitGroupType &V) {
995+
#define ENUM_CASE(Val) I.enumCase(V, #Val, offloadtest::HitGroupType::Val)
996+
ENUM_CASE(Triangles);
997+
ENUM_CASE(Procedural);
889998
#undef ENUM_CASE
890999
}
8911000
};

lib/API/DX/Device.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2879,6 +2879,9 @@ class DXDevice : public offloadtest::Device {
28792879
if (auto Err = createGraphicsCommands(P, State))
28802880
return Err;
28812881
llvm::outs() << "Graphics command list created complete.\n";
2882+
} else if (P.isRayTracing()) {
2883+
return llvm::createStringError(
2884+
"RayTracing pipeline not yet supported on DirectX");
28822885
} else {
28832886
return llvm::createStringError("Pipeline was neither Compute nor Raster");
28842887
}

lib/API/MTL/MTLDevice.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,13 @@ static IRShaderStage getShaderStage(Stages Stage) {
148148
return IRShaderStageAmplification;
149149
case Stages::Mesh:
150150
return IRShaderStageMesh;
151+
case Stages::RayGeneration:
152+
case Stages::Miss:
153+
case Stages::ClosestHit:
154+
case Stages::AnyHit:
155+
case Stages::Intersection:
156+
case Stages::Callable:
157+
llvm_unreachable("RayTracing shaders take a different path on Metal.");
151158
}
152159
llvm_unreachable("All cases handled");
153160
}
@@ -2382,6 +2389,9 @@ class MTLDevice : public offloadtest::Device {
23822389

23832390
if (auto Err = createGraphicsCommands(P, IS))
23842391
return Err;
2392+
} else if (P.isRayTracing()) {
2393+
return llvm::createStringError(
2394+
"RayTracing pipeline not yet supported on Metal");
23852395
}
23862396

23872397
auto SubmitResult = GraphicsQueue.submit(std::move(IS.CB));

lib/API/VK/Device.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,18 @@ static VkShaderStageFlagBits getShaderStageFlag(Stages Stage) {
223223
return VK_SHADER_STAGE_TASK_BIT_EXT;
224224
case Stages::Mesh:
225225
return VK_SHADER_STAGE_MESH_BIT_EXT;
226+
case Stages::RayGeneration:
227+
return VK_SHADER_STAGE_RAYGEN_BIT_KHR;
228+
case Stages::Miss:
229+
return VK_SHADER_STAGE_MISS_BIT_KHR;
230+
case Stages::ClosestHit:
231+
return VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
232+
case Stages::AnyHit:
233+
return VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
234+
case Stages::Intersection:
235+
return VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
236+
case Stages::Callable:
237+
return VK_SHADER_STAGE_CALLABLE_BIT_KHR;
226238
}
227239
llvm_unreachable("All cases handled");
228240
}
@@ -4280,6 +4292,9 @@ class VulkanDevice : public offloadtest::Device {
42804292
if (auto Err = createFrameBuffer(State))
42814293
return Err;
42824294
llvm::outs() << "Frame buffer created.\n";
4295+
} else if (P.isRayTracing()) {
4296+
return llvm::createStringError(
4297+
"RayTracing pipeline not yet supported on Vulkan");
42834298
} else {
42844299
return llvm::createStringError(
42854300
"Pipeline was neither Compute nor Traditional Raster");

0 commit comments

Comments
 (0)