@@ -2681,9 +2681,11 @@ class MTLDevice : public offloadtest::Device {
26812681
26822682 // TODO(manon): We would prefer these to live in GPUOnly memory in the
26832683 // future.
2684- const BufferCreateDesc BufferDesc = BufferCreateDesc::uploadBuffer ();
2685- auto ContribBufferOrErr = createBuffer (" AS-Contributions" , BufferDesc,
2686- InstanceCount * sizeof (uint32_t ));
2684+ const BufferCreateDesc ContribBufferDesc =
2685+ BufferCreateDesc::gpuOnlyStorage ();
2686+ auto ContribBufferOrErr =
2687+ createBuffer (" AS-Contributions" , ContribBufferDesc,
2688+ InstanceCount * sizeof (uint32_t ));
26872689 if (!ContribBufferOrErr)
26882690 return ContribBufferOrErr.takeError ();
26892691 auto ContribBuffer = std::move (*ContribBufferOrErr);
@@ -2702,8 +2704,10 @@ class MTLDevice : public offloadtest::Device {
27022704 Header.accelerationStructureID = AS ->gpuResourceID ()._impl ;
27032705 Header.addressOfInstanceContributions =
27042706 ContribBufferMTL.getBufferPtr ()->gpuAddress ();
2707+
2708+ const BufferCreateDesc HeaderBufferDesc = BufferCreateDesc::uploadBuffer ();
27052709 auto HeaderBufOrErr =
2706- createBufferWithData (*this , " AS-Header" , BufferDesc , &Header,
2710+ createBufferWithData (*this , " AS-Header" , HeaderBufferDesc , &Header,
27072711 sizeof (Header), nullptr , nullptr );
27082712 if (!HeaderBufOrErr)
27092713 return HeaderBufOrErr.takeError ();
@@ -3055,7 +3059,15 @@ llvm::Error MTLComputeEncoder::batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) {
30553059 InstanceASIdx.push_back (Idx);
30563060 }
30573061
3058- auto ContribPtrOrErr = AS ->ContribBuffer ->map ();
3062+ const BufferCreateDesc UploadDesc = BufferCreateDesc::uploadBuffer ();
3063+ const uint32_t ContribBufferSize = AS ->ContribBuffer ->getSizeInBytes ();
3064+ auto ContribUploadBufferOrErr = CB ->Dev ->createBuffer (
3065+ " Contrib Upload Buffer" , UploadDesc, ContribBufferSize);
3066+ if (!ContribUploadBufferOrErr)
3067+ return ContribUploadBufferOrErr.takeError ();
3068+ auto ContribUploadBuffer = std::move (*ContribUploadBufferOrErr);
3069+
3070+ auto ContribPtrOrErr = ContribUploadBuffer->map ();
30593071 if (!ContribPtrOrErr)
30603072 return ContribPtrOrErr.takeError ();
30613073 uint32_t *ContribPtr = static_cast <uint32_t *>(*ContribPtrOrErr);
@@ -3087,13 +3099,19 @@ llvm::Error MTLComputeEncoder::batchBuildAS(llvm::ArrayRef<ASBuildItem> Items) {
30873099 ContribPtr[I] = Src.InstanceContributionToHitGroupIndex &
30883100 0xffffff ; // cut-off to 24-bit to match dx12 and vulkan.
30893101 }
3090- AS ->ContribBuffer ->unmap ();
3102+ ContribUploadBuffer->unmap ();
3103+
3104+ if (auto Err = this ->copyBufferToBuffer (*ContribUploadBuffer.get (), 0 ,
3105+ *AS ->ContribBuffer .get (), 0 ,
3106+ ContribBufferSize))
3107+ return Err;
3108+
3109+ CB ->KeepAliveOwned .push_back (std::move (ContribUploadBuffer));
30913110
30923111 const size_t InstByteSize =
30933112 Native.size () *
30943113 sizeof (MTL ::AccelerationStructureUserIDInstanceDescriptor);
30953114
3096- const BufferCreateDesc UploadDesc = BufferCreateDesc::uploadBuffer ();
30973115 auto InstBufOrErr = offloadtest::createBufferWithData (
30983116 *CB ->Dev , " TLAS-Instances" , UploadDesc, Native.data (), InstByteSize,
30993117 nullptr , nullptr );
0 commit comments