Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/API/Device.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ struct TraditionalRasterPipelineCreateDesc {
case Stages::Compute:
case Stages::Amplification:
case Stages::Mesh:
case Stages::RayGeneration:
case Stages::Miss:
case Stages::ClosestHit:
case Stages::AnyHit:
case Stages::Intersection:
case Stages::Callable:
llvm_unreachable("Not a traditional raster pipeline stage.");
}
}
Expand Down Expand Up @@ -151,6 +157,12 @@ struct MeshShaderRasterPipelineCreateDesc {
case Stages::Domain:
case Stages::Geometry:
case Stages::Compute:
case Stages::RayGeneration:
case Stages::Miss:
case Stages::ClosestHit:
case Stages::AnyHit:
case Stages::Intersection:
case Stages::Callable:
llvm_unreachable("Not a mesh raster pipeline stage.");
}
}
Expand Down
117 changes: 113 additions & 4 deletions include/Support/Pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,53 @@ enum class Stages {

// Mesh Shader Raster
Amplification,
Mesh
Mesh,

// Ray Tracing
RayGeneration,
Miss,
ClosestHit,
AnyHit,
Intersection,
Callable
};
inline constexpr std::array AllStages = {
Stages::Compute, Stages::Vertex, Stages::Hull, Stages::Domain,
Stages::Geometry, Stages::Pixel, Stages::Amplification, Stages::Mesh,
Stages::Compute, Stages::Vertex, Stages::Hull,
Stages::Domain, Stages::Geometry, Stages::Pixel,
Stages::Amplification, Stages::Mesh, Stages::RayGeneration,
Stages::Miss, Stages::ClosestHit, Stages::AnyHit,
Stages::Intersection, Stages::Callable,
};
inline constexpr size_t NumStages = AllStages.size();

enum class ShaderPipelineKind { Compute, TraditionalRaster, MeshShaderRaster };
inline constexpr bool isRayTracingStage(Stages S) {
switch (S) {
case Stages::RayGeneration:
case Stages::Miss:
case Stages::ClosestHit:
case Stages::AnyHit:
case Stages::Intersection:
case Stages::Callable:
return true;
case Stages::Compute:
case Stages::Vertex:
case Stages::Hull:
case Stages::Domain:
case Stages::Geometry:
case Stages::Pixel:
case Stages::Amplification:
case Stages::Mesh:
return false;
}
llvm_unreachable("All stages handled");
}

enum class ShaderPipelineKind {
Compute,
TraditionalRaster,
MeshShaderRaster,
RayTracing
};

enum class Rule { BufferExact, BufferFloatULP, BufferFloatEpsilon };

Expand Down Expand Up @@ -528,6 +566,40 @@ struct AccelerationStructureDescs {
llvm::SmallVector<TLASDesc, 1> TLAS;
};

enum class HitGroupType { Triangles, Procedural };

struct HitGroup {
std::string Name;
HitGroupType Type = HitGroupType::Triangles;
std::string ClosestHit;
std::optional<std::string> AnyHit;
std::optional<std::string> Intersection;
};

struct RayTracingPipelineConfig {
uint32_t MaxTraceRecursionDepth = 1;
uint32_t MaxPayloadSizeInBytes = 0;
uint32_t MaxAttributeSizeInBytes = 8;
std::optional<uint32_t> PipelineFlags;
};

struct SBTEntry {
// For RayGen / Miss / Callable entries: the shader's Entry name.
// For HitGroup entries: the HitGroup's Name.
std::string ShaderName;
// Optional per-record local-root data, laid out as the local root signature
// describes. Not used during PR1 bring-up; reserved here so the schema is
// stable when local root signatures land.
llvm::SmallVector<uint8_t> LocalRootData;
};

struct ShaderBindingTable {
SBTEntry RayGen;
llvm::SmallVector<SBTEntry> Miss;
llvm::SmallVector<SBTEntry> HitGroup;
llvm::SmallVector<SBTEntry> Callable;
};

struct Pipeline {
ShaderPipelineKind Kind;
llvm::SmallVector<Shader> Shaders;
Expand All @@ -541,6 +613,9 @@ struct Pipeline {
llvm::SmallVector<DescriptorSet> Sets;
DispatchParametersSet DispatchParameters;
AccelerationStructureDescs AccelStructs;
std::optional<RayTracingPipelineConfig> RTConfig;
llvm::SmallVector<HitGroup> HitGroups;
std::optional<ShaderBindingTable> SBT;

uint32_t getVertexCount() const {
if (DispatchParameters.VertexCount)
Expand Down Expand Up @@ -608,6 +683,7 @@ struct Pipeline {
bool isRaster() const {
return isTraditionalRaster() || isMeshShaderRaster();
}
bool isRayTracing() const { return Kind == ShaderPipelineKind::RayTracing; }
};
} // namespace offloadtest

Expand All @@ -628,6 +704,8 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::AABBGeometry)
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::BLASDesc)
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::InstanceDesc)
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::TLASDesc)
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::HitGroup)
LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::SBTEntry)

namespace llvm {
namespace yaml {
Expand Down Expand Up @@ -736,6 +814,22 @@ template <> struct MappingTraits<offloadtest::AccelerationStructureDescs> {
static void mapping(IO &I, offloadtest::AccelerationStructureDescs &D);
};

template <> struct MappingTraits<offloadtest::HitGroup> {
static void mapping(IO &I, offloadtest::HitGroup &G);
};

template <> struct MappingTraits<offloadtest::RayTracingPipelineConfig> {
static void mapping(IO &I, offloadtest::RayTracingPipelineConfig &C);
};

template <> struct MappingTraits<offloadtest::SBTEntry> {
static void mapping(IO &I, offloadtest::SBTEntry &E);
};

template <> struct MappingTraits<offloadtest::ShaderBindingTable> {
static void mapping(IO &I, offloadtest::ShaderBindingTable &S);
};

template <> struct ScalarEnumerationTraits<offloadtest::Rule> {
static void enumeration(IO &I, offloadtest::Rule &V) {
#define ENUM_CASE(Val) I.enumCase(V, #Val, offloadtest::Rule::Val)
Expand Down Expand Up @@ -887,6 +981,21 @@ template <> struct ScalarEnumerationTraits<offloadtest::Stages> {
ENUM_CASE(Pixel);
ENUM_CASE(Amplification);
ENUM_CASE(Mesh);
ENUM_CASE(RayGeneration);
ENUM_CASE(Miss);
ENUM_CASE(ClosestHit);
ENUM_CASE(AnyHit);
ENUM_CASE(Intersection);
ENUM_CASE(Callable);
#undef ENUM_CASE
}
};

template <> struct ScalarEnumerationTraits<offloadtest::HitGroupType> {
static void enumeration(IO &I, offloadtest::HitGroupType &V) {
#define ENUM_CASE(Val) I.enumCase(V, #Val, offloadtest::HitGroupType::Val)
ENUM_CASE(Triangles);
ENUM_CASE(Procedural);
#undef ENUM_CASE
}
};
Expand Down
3 changes: 3 additions & 0 deletions lib/API/DX/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3131,6 +3131,9 @@ class DXDevice : public offloadtest::Device {
if (auto Err = createGraphicsCommands(P, State))
return Err;
llvm::outs() << "Graphics command list created complete.\n";
} else if (P.isRayTracing()) {
return llvm::createStringError(
"RayTracing pipeline not yet supported on DirectX");
} else {
return llvm::createStringError("Pipeline was neither Compute nor Raster");
}
Expand Down
10 changes: 10 additions & 0 deletions lib/API/MTL/MTLDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ static IRShaderStage getShaderStage(Stages Stage) {
return IRShaderStageAmplification;
case Stages::Mesh:
return IRShaderStageMesh;
case Stages::RayGeneration:
case Stages::Miss:
case Stages::ClosestHit:
case Stages::AnyHit:
case Stages::Intersection:
case Stages::Callable:
llvm_unreachable("RayTracing shaders take a different path on Metal.");
}
llvm_unreachable("All cases handled");
}
Expand Down Expand Up @@ -2408,6 +2415,9 @@ class MTLDevice : public offloadtest::Device {

if (auto Err = createGraphicsCommands(P, IS))
return Err;
} else if (P.isRayTracing()) {
return llvm::createStringError(
"RayTracing pipeline not yet supported on Metal");
}

auto SubmitResult = GraphicsQueue.submit(std::move(IS.CB));
Expand Down
15 changes: 15 additions & 0 deletions lib/API/VK/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,18 @@ static VkShaderStageFlagBits getShaderStageFlag(Stages Stage) {
return VK_SHADER_STAGE_TASK_BIT_EXT;
case Stages::Mesh:
return VK_SHADER_STAGE_MESH_BIT_EXT;
case Stages::RayGeneration:
return VK_SHADER_STAGE_RAYGEN_BIT_KHR;
case Stages::Miss:
return VK_SHADER_STAGE_MISS_BIT_KHR;
case Stages::ClosestHit:
return VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR;
case Stages::AnyHit:
return VK_SHADER_STAGE_ANY_HIT_BIT_KHR;
case Stages::Intersection:
return VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
case Stages::Callable:
return VK_SHADER_STAGE_CALLABLE_BIT_KHR;
}
llvm_unreachable("All cases handled");
}
Expand Down Expand Up @@ -4450,6 +4462,9 @@ class VulkanDevice : public offloadtest::Device {
if (auto Err = createFrameBuffer(State))
return Err;
llvm::outs() << "Frame buffer created.\n";
} else if (P.isRayTracing()) {
return llvm::createStringError(
"RayTracing pipeline not yet supported on Vulkan");
} else {
return llvm::createStringError(
"Pipeline was neither Compute nor Traditional Raster");
Expand Down
Loading
Loading