1414#define OFFLOADTEST_SUPPORT_PIPELINE_H
1515
1616#include " API/Enums.h"
17+ #include " API/Resources.h"
1718#include " llvm/ADT/SmallVector.h"
1819#include " llvm/ADT/StringRef.h"
1920#include " llvm/Support/Error.h"
@@ -80,6 +81,7 @@ static inline DescriptorKind getDescriptorKind(ResourceKind RK) {
8081 case ResourceKind::StructuredBuffer:
8182 case ResourceKind::ByteAddressBuffer:
8283 case ResourceKind::Texture2D:
84+ case ResourceKind::AccelerationStructure:
8385 return DescriptorKind::SRV ;
8486
8587 case ResourceKind::RWStructuredBuffer:
@@ -221,6 +223,8 @@ struct Result {
221223 double Epsilon;
222224};
223225
226+ struct TLASDesc ;
227+
224228struct Resource {
225229 ResourceKind Kind;
226230 std::string Name;
@@ -231,6 +235,11 @@ struct Resource {
231235 bool HasCounter;
232236 std::optional<uint32_t > TilesMapped;
233237 bool IsReserved = false ;
238+ TLASDesc *TLASPtr = nullptr ;
239+
240+ bool isAccelerationStructure () const {
241+ return Kind == ResourceKind::AccelerationStructure;
242+ }
234243
235244 bool isRaw () const {
236245 switch (Kind) {
@@ -240,6 +249,7 @@ struct Resource {
240249 case ResourceKind::RWTexture2D:
241250 case ResourceKind::Sampler:
242251 case ResourceKind::SampledTexture2D:
252+ case ResourceKind::AccelerationStructure:
243253 return false ;
244254 case ResourceKind::StructuredBuffer:
245255 case ResourceKind::RWStructuredBuffer:
@@ -265,6 +275,7 @@ struct Resource {
265275 case ResourceKind::Texture2D:
266276 case ResourceKind::RWTexture2D:
267277 case ResourceKind::SampledTexture2D:
278+ case ResourceKind::AccelerationStructure:
268279 return false ;
269280 }
270281 llvm_unreachable (" All cases handled" );
@@ -280,6 +291,7 @@ struct Resource {
280291 case ResourceKind::RWByteAddressBuffer:
281292 case ResourceKind::ConstantBuffer:
282293 case ResourceKind::Sampler:
294+ case ResourceKind::AccelerationStructure:
283295 return false ;
284296 case ResourceKind::Texture2D:
285297 case ResourceKind::RWTexture2D:
@@ -319,18 +331,22 @@ struct Resource {
319331 }
320332
321333 uint32_t getElementSize () const {
322- assert (!isSampler () && " Samplers do not have element size" );
334+ assert (!isSampler () && !isAccelerationStructure () &&
335+ " Samplers and AS do not have element size" );
323336 // ByteAddressBuffers are treated as 4-byte elements to match their memory
324337 // format.
325338 return isByteAddressBuffer () ? 4 : BufferPtr->getElementSize ();
326339 }
327340
328341 uint32_t getArraySize () const {
329- return isSampler () ? 1 : BufferPtr->ArraySize ;
342+ if (isSampler () || isAccelerationStructure ())
343+ return 1 ;
344+ return BufferPtr->ArraySize ;
330345 }
331346
332347 uint32_t size () const {
333- assert (!isSampler () && " Samplers do not have size" );
348+ assert (!isSampler () && !isAccelerationStructure () &&
349+ " Samplers and AS do not have size" );
334350 return BufferPtr->size ();
335351 }
336352
@@ -343,6 +359,7 @@ struct Resource {
343359 case ResourceKind::ConstantBuffer:
344360 case ResourceKind::Sampler:
345361 case ResourceKind::SampledTexture2D:
362+ case ResourceKind::AccelerationStructure:
346363 return false ;
347364 case ResourceKind::RWBuffer:
348365 case ResourceKind::RWStructuredBuffer:
@@ -465,6 +482,51 @@ struct DispatchParametersSet {
465482 std::optional<uint32_t > VertexCount;
466483};
467484
485+ struct TriangleGeometry {
486+ std::string VertexBuffer;
487+ CPUBuffer *VertexBufferPtr = nullptr ;
488+ Format VertexFormat = Format::RGB32Float;
489+ uint32_t VertexStride = 12 ;
490+ uint32_t VertexCount = 0 ;
491+ std::string IndexBuffer;
492+ CPUBuffer *IndexBufferPtr = nullptr ;
493+ IndexFormat IdxFormat = IndexFormat::Uint32;
494+ uint32_t IndexCount = 0 ;
495+ bool Opaque = true ;
496+ };
497+
498+ struct AABBGeometry {
499+ std::string AABBBuffer;
500+ CPUBuffer *AABBBufferPtr = nullptr ;
501+ uint32_t AABBCount = 0 ;
502+ uint32_t AABBStride = 24 ;
503+ bool Opaque = true ;
504+ };
505+
506+ struct BLASDesc {
507+ std::string Name;
508+ llvm::SmallVector<TriangleGeometry> Triangles;
509+ llvm::SmallVector<AABBGeometry> AABBs;
510+ };
511+
512+ struct InstanceDesc {
513+ std::string BLAS ;
514+ int BLASIdx = -1 ;
515+ float Transform[12 ] = {1 , 0 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 1 , 0 };
516+ uint32_t InstanceID = 0 ;
517+ uint8_t InstanceMask = 0xFF ;
518+ };
519+
520+ struct TLASDesc {
521+ std::string Name;
522+ llvm::SmallVector<InstanceDesc> Instances;
523+ };
524+
525+ struct AccelerationStructureDescs {
526+ llvm::SmallVector<BLASDesc> BLAS ;
527+ llvm::SmallVector<TLASDesc> TLAS ;
528+ };
529+
468530struct Pipeline {
469531 ShaderPipelineKind Kind;
470532 llvm::SmallVector<Shader> Shaders;
@@ -477,6 +539,7 @@ struct Pipeline {
477539 llvm::SmallVector<Result> Results;
478540 llvm::SmallVector<DescriptorSet> Sets;
479541 DispatchParametersSet DispatchParameters;
542+ AccelerationStructureDescs AccelStructs;
480543
481544 uint32_t getVertexCount () const {
482545 if (DispatchParameters.VertexCount )
@@ -517,6 +580,20 @@ struct Pipeline {
517580 return nullptr ;
518581 }
519582
583+ BLASDesc *getBLAS (llvm::StringRef Name) {
584+ for (auto &B : AccelStructs.BLAS )
585+ if (Name == B.Name )
586+ return &B;
587+ return nullptr ;
588+ }
589+
590+ TLASDesc *getTLAS (llvm::StringRef Name) {
591+ for (auto &T : AccelStructs.TLAS )
592+ if (Name == T.Name )
593+ return &T;
594+ return nullptr ;
595+ }
596+
520597 llvm::Error validatePipelineKind ();
521598 llvm::Error validateDispatchParameters ();
522599
@@ -545,6 +622,11 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::SpecializationConstant)
545622LLVM_YAML_IS_SEQUENCE_VECTOR (offloadtest::PushConstantBlock)
546623LLVM_YAML_IS_SEQUENCE_VECTOR (offloadtest::PushConstantValue)
547624LLVM_YAML_IS_SEQUENCE_VECTOR (offloadtest::DispatchParametersSet)
625+ LLVM_YAML_IS_SEQUENCE_VECTOR (offloadtest::TriangleGeometry)
626+ LLVM_YAML_IS_SEQUENCE_VECTOR (offloadtest::AABBGeometry)
627+ LLVM_YAML_IS_SEQUENCE_VECTOR (offloadtest::BLASDesc)
628+ LLVM_YAML_IS_SEQUENCE_VECTOR (offloadtest::InstanceDesc)
629+ LLVM_YAML_IS_SEQUENCE_VECTOR (offloadtest::TLASDesc)
548630
549631namespace llvm {
550632namespace yaml {
@@ -629,6 +711,30 @@ template <> struct MappingTraits<offloadtest::SpecializationConstant> {
629711 static void mapping (IO &I, offloadtest::SpecializationConstant &C);
630712};
631713
714+ template <> struct MappingTraits <offloadtest::TriangleGeometry> {
715+ static void mapping (IO &I, offloadtest::TriangleGeometry &G);
716+ };
717+
718+ template <> struct MappingTraits <offloadtest::AABBGeometry> {
719+ static void mapping (IO &I, offloadtest::AABBGeometry &G);
720+ };
721+
722+ template <> struct MappingTraits <offloadtest::BLASDesc> {
723+ static void mapping (IO &I, offloadtest::BLASDesc &D);
724+ };
725+
726+ template <> struct MappingTraits <offloadtest::InstanceDesc> {
727+ static void mapping (IO &I, offloadtest::InstanceDesc &D);
728+ };
729+
730+ template <> struct MappingTraits <offloadtest::TLASDesc> {
731+ static void mapping (IO &I, offloadtest::TLASDesc &D);
732+ };
733+
734+ template <> struct MappingTraits <offloadtest::AccelerationStructureDescs> {
735+ static void mapping (IO &I, offloadtest::AccelerationStructureDescs &D);
736+ };
737+
632738template <> struct ScalarEnumerationTraits <offloadtest::Rule> {
633739 static void enumeration (IO &I, offloadtest::Rule &V) {
634740#define ENUM_CASE (Val ) I.enumCase(V, #Val, offloadtest::Rule::Val)
@@ -730,6 +836,41 @@ template <> struct ScalarEnumerationTraits<offloadtest::ResourceKind> {
730836 ENUM_CASE (ConstantBuffer);
731837 ENUM_CASE (Sampler);
732838 ENUM_CASE (SampledTexture2D);
839+ ENUM_CASE (AccelerationStructure);
840+ #undef ENUM_CASE
841+ }
842+ };
843+
844+ template <> struct ScalarEnumerationTraits <offloadtest::Format> {
845+ static void enumeration (IO &I, offloadtest::Format &V) {
846+ #define ENUM_CASE (Val ) I.enumCase(V, #Val, offloadtest::Format::Val)
847+ ENUM_CASE (R16Sint);
848+ ENUM_CASE (R16Uint);
849+ ENUM_CASE (RG16Sint);
850+ ENUM_CASE (RG16Uint);
851+ ENUM_CASE (RGBA16Sint);
852+ ENUM_CASE (RGBA16Uint);
853+ ENUM_CASE (R32Sint);
854+ ENUM_CASE (R32Uint);
855+ ENUM_CASE (R32Float);
856+ ENUM_CASE (RG32Sint);
857+ ENUM_CASE (RG32Uint);
858+ ENUM_CASE (RG32Float);
859+ ENUM_CASE (RGB32Float);
860+ ENUM_CASE (RGBA32Sint);
861+ ENUM_CASE (RGBA32Uint);
862+ ENUM_CASE (RGBA32Float);
863+ ENUM_CASE (D32Float);
864+ ENUM_CASE (D32FloatS8Uint);
865+ #undef ENUM_CASE
866+ }
867+ };
868+
869+ template <> struct ScalarEnumerationTraits <offloadtest::IndexFormat> {
870+ static void enumeration (IO &I, offloadtest::IndexFormat &V) {
871+ #define ENUM_CASE (Val ) I.enumCase(V, #Val, offloadtest::IndexFormat::Val)
872+ ENUM_CASE (Uint16);
873+ ENUM_CASE (Uint32);
733874#undef ENUM_CASE
734875 }
735876};
0 commit comments