1414#define OFFLOADTEST_SUPPORT_PIPELINE_H
1515
1616#include " API/Enums.h"
17+ #include " API/Resources.h"
1718#include " llvm/ADT/SmallVector.h"
1819#include " llvm/ADT/StringRef.h"
1920#include " llvm/Support/Error.h"
@@ -91,6 +92,8 @@ static inline DescriptorKind getDescriptorKind(ResourceKind RK) {
9192 return DescriptorKind::SAMPLER ;
9293 case ResourceKind::SampledTexture2D:
9394 llvm_unreachable (" Sampled textures aren't supported!" );
95+ case ResourceKind::AccelerationStructure:
96+ return DescriptorKind::SRV ;
9497 }
9598 llvm_unreachable (" All cases handled" );
9699}
@@ -217,6 +220,8 @@ struct Result {
217220 double Epsilon;
218221};
219222
223+ struct TLASDesc ;
224+
220225struct Resource {
221226 ResourceKind Kind;
222227 std::string Name;
@@ -227,6 +232,11 @@ struct Resource {
227232 bool HasCounter;
228233 std::optional<uint32_t > TilesMapped;
229234 bool IsReserved = false ;
235+ TLASDesc *TLASPtr = nullptr ;
236+
237+ bool isAccelerationStructure () const {
238+ return Kind == ResourceKind::AccelerationStructure;
239+ }
230240
231241 bool isRaw () const {
232242 switch (Kind) {
@@ -236,6 +246,7 @@ struct Resource {
236246 case ResourceKind::RWTexture2D:
237247 case ResourceKind::Sampler:
238248 case ResourceKind::SampledTexture2D:
249+ case ResourceKind::AccelerationStructure:
239250 return false ;
240251 case ResourceKind::StructuredBuffer:
241252 case ResourceKind::RWStructuredBuffer:
@@ -261,6 +272,7 @@ struct Resource {
261272 case ResourceKind::Texture2D:
262273 case ResourceKind::RWTexture2D:
263274 case ResourceKind::SampledTexture2D:
275+ case ResourceKind::AccelerationStructure:
264276 return false ;
265277 }
266278 llvm_unreachable (" All cases handled" );
@@ -276,6 +288,7 @@ struct Resource {
276288 case ResourceKind::RWByteAddressBuffer:
277289 case ResourceKind::ConstantBuffer:
278290 case ResourceKind::Sampler:
291+ case ResourceKind::AccelerationStructure:
279292 return false ;
280293 case ResourceKind::Texture2D:
281294 case ResourceKind::RWTexture2D:
@@ -315,18 +328,22 @@ struct Resource {
315328 }
316329
317330 uint32_t getElementSize () const {
318- assert (!isSampler () && " Samplers do not have element size" );
331+ assert (!isSampler () && !isAccelerationStructure () &&
332+ " Samplers and AS do not have element size" );
319333 // ByteAddressBuffers are treated as 4-byte elements to match their memory
320334 // format.
321335 return isByteAddressBuffer () ? 4 : BufferPtr->getElementSize ();
322336 }
323337
324338 uint32_t getArraySize () const {
325- return isSampler () ? 1 : BufferPtr->ArraySize ;
339+ if (isSampler () || isAccelerationStructure ())
340+ return 1 ;
341+ return BufferPtr->ArraySize ;
326342 }
327343
328344 uint32_t size () const {
329- assert (!isSampler () && " Samplers do not have size" );
345+ assert (!isSampler () && !isAccelerationStructure () &&
346+ " Samplers and AS do not have size" );
330347 return BufferPtr->size ();
331348 }
332349
@@ -339,6 +356,7 @@ struct Resource {
339356 case ResourceKind::ConstantBuffer:
340357 case ResourceKind::Sampler:
341358 case ResourceKind::SampledTexture2D:
359+ case ResourceKind::AccelerationStructure:
342360 return false ;
343361 case ResourceKind::RWBuffer:
344362 case ResourceKind::RWStructuredBuffer:
@@ -454,6 +472,53 @@ struct DispatchParametersSet {
454472 std::optional<uint32_t > VertexCount;
455473};
456474
475+ enum class IndexFormat { Uint16, Uint32 };
476+
477+ struct TriangleGeometry {
478+ std::string VertexBuffer;
479+ CPUBuffer *VertexBufferPtr = nullptr ;
480+ Format VertexFormat = Format::RGB32Float;
481+ uint32_t VertexStride = 12 ;
482+ uint32_t VertexCount = 0 ;
483+ std::string IndexBuffer;
484+ CPUBuffer *IndexBufferPtr = nullptr ;
485+ IndexFormat IdxFormat = IndexFormat::Uint32;
486+ uint32_t IndexCount = 0 ;
487+ bool Opaque = true ;
488+ };
489+
490+ struct AABBGeometry {
491+ std::string AABBBuffer;
492+ CPUBuffer *AABBBufferPtr = nullptr ;
493+ uint32_t AABBCount = 0 ;
494+ uint32_t AABBStride = 24 ;
495+ bool Opaque = true ;
496+ };
497+
498+ struct BLASDesc {
499+ std::string Name;
500+ llvm::SmallVector<TriangleGeometry> Triangles;
501+ llvm::SmallVector<AABBGeometry> AABBs;
502+ };
503+
504+ struct InstanceDesc {
505+ std::string BLAS ;
506+ int BLASIdx = -1 ;
507+ float Transform[12 ] = {1 , 0 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 , 1 , 0 };
508+ uint32_t InstanceID = 0 ;
509+ uint8_t InstanceMask = 0xFF ;
510+ };
511+
512+ struct TLASDesc {
513+ std::string Name;
514+ llvm::SmallVector<InstanceDesc> Instances;
515+ };
516+
517+ struct AccelerationStructureDescs {
518+ llvm::SmallVector<BLASDesc> BLAS ;
519+ llvm::SmallVector<TLASDesc> TLAS ;
520+ };
521+
457522struct Pipeline {
458523 ShaderPipelineKind Kind;
459524 llvm::SmallVector<Shader> Shaders;
@@ -466,6 +531,7 @@ struct Pipeline {
466531 llvm::SmallVector<Result> Results;
467532 llvm::SmallVector<DescriptorSet> Sets;
468533 DispatchParametersSet DispatchParameters;
534+ AccelerationStructureDescs AccelStructs;
469535
470536 uint32_t getVertexCount () const {
471537 if (DispatchParameters.VertexCount )
@@ -506,6 +572,20 @@ struct Pipeline {
506572 return nullptr ;
507573 }
508574
575+ BLASDesc *getBLAS (llvm::StringRef Name) {
576+ for (auto &B : AccelStructs.BLAS )
577+ if (Name == B.Name )
578+ return &B;
579+ return nullptr ;
580+ }
581+
582+ TLASDesc *getTLAS (llvm::StringRef Name) {
583+ for (auto &T : AccelStructs.TLAS )
584+ if (Name == T.Name )
585+ return &T;
586+ return nullptr ;
587+ }
588+
509589 llvm::Error validatePipelineKind ();
510590 llvm::Error validateDispatchParameters ();
511591
@@ -534,6 +614,11 @@ LLVM_YAML_IS_SEQUENCE_VECTOR(offloadtest::SpecializationConstant)
534614LLVM_YAML_IS_SEQUENCE_VECTOR (offloadtest::PushConstantBlock)
535615LLVM_YAML_IS_SEQUENCE_VECTOR (offloadtest::PushConstantValue)
536616LLVM_YAML_IS_SEQUENCE_VECTOR (offloadtest::DispatchParametersSet)
617+ LLVM_YAML_IS_SEQUENCE_VECTOR (offloadtest::TriangleGeometry)
618+ LLVM_YAML_IS_SEQUENCE_VECTOR (offloadtest::AABBGeometry)
619+ LLVM_YAML_IS_SEQUENCE_VECTOR (offloadtest::BLASDesc)
620+ LLVM_YAML_IS_SEQUENCE_VECTOR (offloadtest::InstanceDesc)
621+ LLVM_YAML_IS_SEQUENCE_VECTOR (offloadtest::TLASDesc)
537622
538623namespace llvm {
539624namespace yaml {
@@ -618,6 +703,30 @@ template <> struct MappingTraits<offloadtest::SpecializationConstant> {
618703 static void mapping (IO &I, offloadtest::SpecializationConstant &C);
619704};
620705
706+ template <> struct MappingTraits <offloadtest::TriangleGeometry> {
707+ static void mapping (IO &I, offloadtest::TriangleGeometry &G);
708+ };
709+
710+ template <> struct MappingTraits <offloadtest::AABBGeometry> {
711+ static void mapping (IO &I, offloadtest::AABBGeometry &G);
712+ };
713+
714+ template <> struct MappingTraits <offloadtest::BLASDesc> {
715+ static void mapping (IO &I, offloadtest::BLASDesc &D);
716+ };
717+
718+ template <> struct MappingTraits <offloadtest::InstanceDesc> {
719+ static void mapping (IO &I, offloadtest::InstanceDesc &D);
720+ };
721+
722+ template <> struct MappingTraits <offloadtest::TLASDesc> {
723+ static void mapping (IO &I, offloadtest::TLASDesc &D);
724+ };
725+
726+ template <> struct MappingTraits <offloadtest::AccelerationStructureDescs> {
727+ static void mapping (IO &I, offloadtest::AccelerationStructureDescs &D);
728+ };
729+
621730template <> struct ScalarEnumerationTraits <offloadtest::Rule> {
622731 static void enumeration (IO &I, offloadtest::Rule &V) {
623732#define ENUM_CASE (Val ) I.enumCase(V, #Val, offloadtest::Rule::Val)
@@ -719,6 +828,41 @@ template <> struct ScalarEnumerationTraits<offloadtest::ResourceKind> {
719828 ENUM_CASE (ConstantBuffer);
720829 ENUM_CASE (Sampler);
721830 ENUM_CASE (SampledTexture2D);
831+ ENUM_CASE (AccelerationStructure);
832+ #undef ENUM_CASE
833+ }
834+ };
835+
836+ template <> struct ScalarEnumerationTraits <offloadtest::Format> {
837+ static void enumeration (IO &I, offloadtest::Format &V) {
838+ #define ENUM_CASE (Val ) I.enumCase(V, #Val, offloadtest::Format::Val)
839+ ENUM_CASE (R16Sint);
840+ ENUM_CASE (R16Uint);
841+ ENUM_CASE (RG16Sint);
842+ ENUM_CASE (RG16Uint);
843+ ENUM_CASE (RGBA16Sint);
844+ ENUM_CASE (RGBA16Uint);
845+ ENUM_CASE (R32Sint);
846+ ENUM_CASE (R32Uint);
847+ ENUM_CASE (R32Float);
848+ ENUM_CASE (RG32Sint);
849+ ENUM_CASE (RG32Uint);
850+ ENUM_CASE (RG32Float);
851+ ENUM_CASE (RGB32Float);
852+ ENUM_CASE (RGBA32Sint);
853+ ENUM_CASE (RGBA32Uint);
854+ ENUM_CASE (RGBA32Float);
855+ ENUM_CASE (D32Float);
856+ ENUM_CASE (D32FloatS8Uint);
857+ #undef ENUM_CASE
858+ }
859+ };
860+
861+ template <> struct ScalarEnumerationTraits <offloadtest::IndexFormat> {
862+ static void enumeration (IO &I, offloadtest::IndexFormat &V) {
863+ #define ENUM_CASE (Val ) I.enumCase(V, #Val, offloadtest::IndexFormat::Val)
864+ ENUM_CASE (Uint16);
865+ ENUM_CASE (Uint32);
722866#undef ENUM_CASE
723867 }
724868};
0 commit comments