Skip to content

Commit 3596fcc

Browse files
MarijnS95claude
andcommitted
Add RT acceleration structure abstraction with size queries and resource allocation
Introduce the foundational types for ray tracing acceleration structures: abstract `AccelerationStructure` base class, geometry/instance descriptors, BLAS/TLAS build-request structs with size queries, and AS resource allocation across DX12, Vulkan, and Metal backends. Recording the actual build commands lands in a follow-up commit on top of the ComputeEncoder abstraction. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0c930ef commit 3596fcc

9 files changed

Lines changed: 988 additions & 41 deletions

File tree

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
//===- AccelerationStructure.h - RT Acceleration Structure Types ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef OFFLOADTEST_API_ACCELERATIONSTRUCTURE_H
10+
#define OFFLOADTEST_API_ACCELERATIONSTRUCTURE_H
11+
12+
#include "API/API.h"
13+
#include "API/Buffer.h"
14+
#include "API/Resources.h"
15+
16+
#include "llvm/ADT/ArrayRef.h"
17+
#include "llvm/ADT/SmallVector.h"
18+
#include "llvm/Support/Error.h"
19+
20+
#include <cstdint>
21+
22+
namespace offloadtest {
23+
24+
enum AccelerationStructureBuildFlags : uint32_t {
25+
BuildFlagNone = 0,
26+
AllowUpdate = 1 << 0,
27+
PreferFastTrace = 1 << 1,
28+
PreferFastBuild = 1 << 2,
29+
};
30+
31+
inline AccelerationStructureBuildFlags
32+
operator|(AccelerationStructureBuildFlags A,
33+
AccelerationStructureBuildFlags B) {
34+
return static_cast<AccelerationStructureBuildFlags>(static_cast<uint32_t>(A) |
35+
static_cast<uint32_t>(B));
36+
}
37+
38+
inline AccelerationStructureBuildFlags
39+
operator&(AccelerationStructureBuildFlags A,
40+
AccelerationStructureBuildFlags B) {
41+
return static_cast<AccelerationStructureBuildFlags>(static_cast<uint32_t>(A) &
42+
static_cast<uint32_t>(B));
43+
}
44+
45+
struct AccelerationStructureSizes {
46+
uint64_t ResultDataMaxSizeInBytes = 0;
47+
uint64_t ScratchDataSizeInBytes = 0;
48+
uint64_t UpdateScratchDataSizeInBytes = 0;
49+
};
50+
51+
struct TriangleGeometryDesc {
52+
Buffer *VertexBuffer = nullptr;
53+
uint64_t VertexBufferOffset = 0;
54+
uint32_t VertexCount = 0;
55+
uint32_t VertexStride = 0;
56+
Format VertexFormat = Format::RGB32Float;
57+
Buffer *IndexBuffer = nullptr;
58+
uint64_t IndexBufferOffset = 0;
59+
uint32_t IndexCount = 0;
60+
IndexFormat IdxFormat = IndexFormat::Uint32;
61+
bool Opaque = true;
62+
};
63+
64+
struct AABBGeometryDesc {
65+
Buffer *AABBBuffer = nullptr;
66+
uint64_t AABBBufferOffset = 0;
67+
uint32_t AABBCount = 0;
68+
uint32_t AABBStride = 24;
69+
bool Opaque = true;
70+
};
71+
72+
class AccelerationStructure;
73+
74+
struct AccelerationStructureInstance {
75+
float Transform[3][4] = {{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}};
76+
uint32_t InstanceID = 0;
77+
uint8_t InstanceMask = 0xFF;
78+
AccelerationStructure *BLAS = nullptr;
79+
};
80+
81+
struct BLASBuildRequest {
82+
llvm::SmallVector<TriangleGeometryDesc> Triangles;
83+
llvm::SmallVector<AABBGeometryDesc> AABBs;
84+
AccelerationStructureBuildFlags Flags = BuildFlagNone;
85+
AccelerationStructureSizes Sizes;
86+
};
87+
88+
struct TLASBuildRequest {
89+
llvm::SmallVector<AccelerationStructureInstance> Instances;
90+
AccelerationStructureBuildFlags Flags = BuildFlagNone;
91+
AccelerationStructureSizes Sizes;
92+
};
93+
94+
inline llvm::Error validateTriangleGeometryDesc(const TriangleGeometryDesc &D) {
95+
if (!D.VertexBuffer)
96+
return llvm::createStringError(
97+
std::errc::invalid_argument,
98+
"TriangleGeometryDesc: VertexBuffer is null.");
99+
if (!isPositionCompatible(D.VertexFormat))
100+
return llvm::createStringError(
101+
std::errc::invalid_argument,
102+
"TriangleGeometryDesc: VertexFormat '%s' is not position-compatible.",
103+
getFormatName(D.VertexFormat).data());
104+
if (D.VertexStride < getFormatSizeInBytes(D.VertexFormat))
105+
return llvm::createStringError(
106+
std::errc::invalid_argument,
107+
"TriangleGeometryDesc: VertexStride (%u) must be >= format size (%u).",
108+
D.VertexStride, getFormatSizeInBytes(D.VertexFormat));
109+
if (D.VertexCount == 0)
110+
return llvm::createStringError(std::errc::invalid_argument,
111+
"TriangleGeometryDesc: VertexCount is 0.");
112+
if (D.IndexBuffer && D.IndexCount == 0)
113+
return llvm::createStringError(
114+
std::errc::invalid_argument,
115+
"TriangleGeometryDesc: IndexBuffer is set but IndexCount is 0.");
116+
if (!D.IndexBuffer && D.IndexCount != 0)
117+
return llvm::createStringError(
118+
std::errc::invalid_argument,
119+
"TriangleGeometryDesc: IndexCount is set but IndexBuffer is null.");
120+
if (D.IndexBuffer && D.IndexCount % 3 != 0)
121+
return llvm::createStringError(
122+
std::errc::invalid_argument,
123+
"TriangleGeometryDesc: IndexCount (%u) must be a multiple of 3.",
124+
D.IndexCount);
125+
if (!D.IndexBuffer && D.VertexCount % 3 != 0)
126+
return llvm::createStringError(
127+
std::errc::invalid_argument,
128+
"TriangleGeometryDesc: VertexCount (%u) must be a multiple of 3 when "
129+
"no index buffer is provided.",
130+
D.VertexCount);
131+
return llvm::Error::success();
132+
}
133+
134+
inline llvm::Error validateAABBGeometryDesc(const AABBGeometryDesc &D) {
135+
if (!D.AABBBuffer)
136+
return llvm::createStringError(std::errc::invalid_argument,
137+
"AABBGeometryDesc: AABBBuffer is null.");
138+
if (D.AABBCount == 0)
139+
return llvm::createStringError(std::errc::invalid_argument,
140+
"AABBGeometryDesc: AABBCount is 0.");
141+
if (D.AABBStride < 24)
142+
return llvm::createStringError(
143+
std::errc::invalid_argument,
144+
"AABBGeometryDesc: AABBStride (%u) must be >= 24.", D.AABBStride);
145+
return llvm::Error::success();
146+
}
147+
148+
inline llvm::Error validateBLASBuildRequest(const BLASBuildRequest &Req) {
149+
if (Req.Triangles.empty() && Req.AABBs.empty())
150+
return llvm::createStringError(
151+
std::errc::invalid_argument,
152+
"BLASBuildRequest: Must have at least one geometry descriptor.");
153+
// Vulkan and Metal forbid mixing triangle and AABB geometry in a single BLAS;
154+
// D3D12 technically permits it but is uncommon. Reject for portability.
155+
if (!Req.Triangles.empty() && !Req.AABBs.empty())
156+
return llvm::createStringError(
157+
std::errc::invalid_argument,
158+
"BLASBuildRequest: Cannot mix triangle and AABB geometry in a "
159+
"single BLAS.");
160+
for (const auto &T : Req.Triangles)
161+
if (auto Err = validateTriangleGeometryDesc(T))
162+
return Err;
163+
for (const auto &A : Req.AABBs)
164+
if (auto Err = validateAABBGeometryDesc(A))
165+
return Err;
166+
return llvm::Error::success();
167+
}
168+
169+
inline llvm::Error validateTLASBuildRequest(const TLASBuildRequest &Req) {
170+
if (Req.Instances.empty())
171+
return llvm::createStringError(
172+
std::errc::invalid_argument,
173+
"TLASBuildRequest: Must have at least one instance.");
174+
for (size_t I = 0; I < Req.Instances.size(); ++I)
175+
if (!Req.Instances[I].BLAS)
176+
return llvm::createStringError(
177+
std::errc::invalid_argument,
178+
"TLASBuildRequest: Instance %zu has a null BLAS pointer.", I);
179+
return llvm::Error::success();
180+
}
181+
182+
class AccelerationStructure {
183+
GPUAPI API;
184+
185+
public:
186+
virtual ~AccelerationStructure();
187+
AccelerationStructure(const AccelerationStructure &) = delete;
188+
AccelerationStructure &operator=(const AccelerationStructure &) = delete;
189+
190+
GPUAPI getAPI() const { return API; }
191+
192+
protected:
193+
explicit AccelerationStructure(GPUAPI API) : API(API) {}
194+
};
195+
196+
} // namespace offloadtest
197+
198+
#endif // OFFLOADTEST_API_ACCELERATIONSTRUCTURE_H

include/API/Device.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "Config.h"
1616

1717
#include "API/API.h"
18+
#include "API/AccelerationStructure.h"
1819
#include "API/Buffer.h"
1920
#include "API/Capabilities.h"
2021
#include "API/CommandBuffer.h"
@@ -226,6 +227,21 @@ class Device {
226227
virtual llvm::Expected<std::unique_ptr<CommandBuffer>>
227228
createCommandBuffer() = 0;
228229

230+
virtual llvm::Expected<BLASBuildRequest> createBLASBuildRequest(
231+
llvm::ArrayRef<TriangleGeometryDesc> Triangles,
232+
llvm::ArrayRef<AABBGeometryDesc> AABBs,
233+
AccelerationStructureBuildFlags Flags = BuildFlagNone) = 0;
234+
235+
virtual llvm::Expected<TLASBuildRequest> createTLASBuildRequest(
236+
llvm::ArrayRef<AccelerationStructureInstance> Instances,
237+
AccelerationStructureBuildFlags Flags = BuildFlagNone) = 0;
238+
239+
virtual llvm::Expected<std::unique_ptr<AccelerationStructure>>
240+
createAccelerationStructure(const BLASBuildRequest &Request) = 0;
241+
242+
virtual llvm::Expected<std::unique_ptr<AccelerationStructure>>
243+
createAccelerationStructure(const TLASBuildRequest &Request) = 0;
244+
229245
virtual ~Device() = 0;
230246

231247
llvm::StringRef getDescription() const { return Description; }

lib/API/DX/DXResources.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,16 @@ inline D3D12_RESOURCE_FLAGS getDXResourceFlags(TextureUsage Usage) {
8484
return Flags;
8585
}
8686

87+
inline DXGI_FORMAT getDXGIIndexFormat(IndexFormat Fmt) {
88+
switch (Fmt) {
89+
case IndexFormat::Uint16:
90+
return DXGI_FORMAT_R16_UINT;
91+
case IndexFormat::Uint32:
92+
return DXGI_FORMAT_R32_UINT;
93+
}
94+
llvm_unreachable("All IndexFormat cases handled");
95+
}
96+
8797
} // namespace offloadtest
8898

8999
#endif // OFFLOADTEST_API_DXRESOURCES_H

0 commit comments

Comments
 (0)