Skip to content

Commit 84855d2

Browse files
Add staging buffer for contributions buffer. Prevent race conditions.
1 parent 120cf3e commit 84855d2

2 files changed

Lines changed: 28 additions & 8 deletions

File tree

include/API/Buffer.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,16 @@ struct BufferCreateDesc {
7070
false};
7171
}
7272

73-
static BufferCreateDesc scratchBuffer() {
73+
static BufferCreateDesc gpuOnlyStorage() {
7474
return BufferCreateDesc{MemoryLocation::GpuOnly,
7575
MemoryBacking::Automatic,
7676
BufferUsage::Storage,
7777
BufferShaderAccessType::Raw,
7878
{},
7979
false};
8080
}
81+
82+
static BufferCreateDesc scratchBuffer() { return gpuOnlyStorage(); }
8183
};
8284

8385
class Buffer {

lib/API/MTL/MTLDevice.cpp

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2681,9 +2681,11 @@ class MTLDevice : public offloadtest::Device {
26812681

26822682
// TODO(manon): We would prefer these to live in GPUOnly memory in the
26832683
// future.
2684-
const BufferCreateDesc BufferDesc = BufferCreateDesc::uploadBuffer();
2685-
auto ContribBufferOrErr = createBuffer("AS-Contributions", BufferDesc,
2686-
InstanceCount * sizeof(uint32_t));
2684+
const BufferCreateDesc ContribBufferDesc =
2685+
BufferCreateDesc::gpuOnlyStorage();
2686+
auto ContribBufferOrErr =
2687+
createBuffer("AS-Contributions", ContribBufferDesc,
2688+
InstanceCount * sizeof(uint32_t));
26872689
if (!ContribBufferOrErr)
26882690
return ContribBufferOrErr.takeError();
26892691
auto ContribBuffer = std::move(*ContribBufferOrErr);
@@ -2702,8 +2704,10 @@ class MTLDevice : public offloadtest::Device {
27022704
Header.accelerationStructureID = AS->gpuResourceID()._impl;
27032705
Header.addressOfInstanceContributions =
27042706
ContribBufferMTL.getBufferPtr()->gpuAddress();
2707+
2708+
const BufferCreateDesc HeaderBufferDesc = BufferCreateDesc::uploadBuffer();
27052709
auto HeaderBufOrErr =
2706-
createBufferWithData(*this, "AS-Header", BufferDesc, &Header,
2710+
createBufferWithData(*this, "AS-Header", HeaderBufferDesc, &Header,
27072711
sizeof(Header), nullptr, nullptr);
27082712
if (!HeaderBufOrErr)
27092713
return HeaderBufOrErr.takeError();
@@ -3055,7 +3059,15 @@ llvm::Error MTLComputeEncoder::batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) {
30553059
InstanceASIdx.push_back(Idx);
30563060
}
30573061

3058-
auto ContribPtrOrErr = AS->ContribBuffer->map();
3062+
const BufferCreateDesc UploadDesc = BufferCreateDesc::uploadBuffer();
3063+
const uint32_t ContribBufferSize = AS->ContribBuffer->getSizeInBytes();
3064+
auto ContribUploadBufferOrErr = CB->Dev->createBuffer(
3065+
"Contrib Upload Buffer", UploadDesc, ContribBufferSize);
3066+
if (!ContribUploadBufferOrErr)
3067+
return ContribUploadBufferOrErr.takeError();
3068+
auto ContribUploadBuffer = std::move(*ContribUploadBufferOrErr);
3069+
3070+
auto ContribPtrOrErr = ContribUploadBuffer->map();
30593071
if (!ContribPtrOrErr)
30603072
return ContribPtrOrErr.takeError();
30613073
uint32_t *ContribPtr = static_cast<uint32_t *>(*ContribPtrOrErr);
@@ -3087,13 +3099,19 @@ llvm::Error MTLComputeEncoder::batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) {
30873099
ContribPtr[I] = Src.InstanceContributionToHitGroupIndex &
30883100
0xffffff; // cut-off to 24-bit to match dx12 and vulkan.
30893101
}
3090-
AS->ContribBuffer->unmap();
3102+
ContribUploadBuffer->unmap();
3103+
3104+
if (auto Err = this->copyBufferToBuffer(*ContribUploadBuffer.get(), 0,
3105+
*AS->ContribBuffer.get(), 0,
3106+
ContribBufferSize))
3107+
return Err;
3108+
3109+
CB->KeepAliveOwned.push_back(std::move(ContribUploadBuffer));
30913110

30923111
const size_t InstByteSize =
30933112
Native.size() *
30943113
sizeof(MTL::AccelerationStructureUserIDInstanceDescriptor);
30953114

3096-
const BufferCreateDesc UploadDesc = BufferCreateDesc::uploadBuffer();
30973115
auto InstBufOrErr = offloadtest::createBufferWithData(
30983116
*CB->Dev, "TLAS-Instances", UploadDesc, Native.data(), InstByteSize,
30993117
nullptr, nullptr);

0 commit comments

Comments
 (0)