Skip to content

Commit b94a072

Browse files
authored
Merge pull request #2685 from johang88/vulkan-compute
feat: Added compute shader support for vulkan
2 parents 4903592 + 143c3de commit b94a072

15 files changed

Lines changed: 651 additions & 252 deletions

File tree

sources/engine/Stride.Graphics/Vulkan/Buffer.Vulkan.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,18 @@ public unsafe void Recreate(IntPtr dataPointer)
126126
NativePipelineStageMask |= VkPipelineStageFlags.VertexShader | VkPipelineStageFlags.FragmentShader;
127127
}
128128

129+
if ((ViewFlags & BufferFlags.StructuredBuffer) != 0)
130+
{
131+
createInfo.usage |= VkBufferUsageFlags.StorageBuffer;
132+
NativeAccessMask |= VkAccessFlags.UniformRead;
133+
NativePipelineStageMask |= VkPipelineStageFlags.VertexShader | VkPipelineStageFlags.FragmentShader;
134+
135+
if ((ViewFlags & BufferFlags.UnorderedAccess) != 0)
136+
{
137+
NativeAccessMask |= VkAccessFlags.ShaderWrite;
138+
}
139+
}
140+
129141
if ((ViewFlags & BufferFlags.ShaderResource) != 0)
130142
{
131143
createInfo.usage |= VkBufferUsageFlags.UniformTexelBuffer;

sources/engine/Stride.Graphics/Vulkan/CommandList.Vulkan.cs

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ private unsafe void FlushInternal(bool wait)
130130

131131
if (activePipeline != null)
132132
{
133-
vkCmdBindPipeline(currentCommandList.NativeCommandBuffer, VkPipelineBindPoint.Graphics, activePipeline.NativePipeline);
133+
vkCmdBindPipeline(currentCommandList.NativeCommandBuffer, activePipeline.IsCompute ? VkPipelineBindPoint.Compute : VkPipelineBindPoint.Graphics, activePipeline.NativePipeline);
134134
var descriptorSetCopy = descriptorSet;
135-
vkCmdBindDescriptorSets(currentCommandList.NativeCommandBuffer, VkPipelineBindPoint.Graphics, activePipeline.NativeLayout, firstSet: 0, descriptorSetCount: 1, &descriptorSetCopy, dynamicOffsetCount: 0, dynamicOffsets: null);
135+
vkCmdBindDescriptorSets(currentCommandList.NativeCommandBuffer, activePipeline.IsCompute ? VkPipelineBindPoint.Compute : VkPipelineBindPoint.Graphics, activePipeline.NativeLayout, firstSet: 0, descriptorSetCount: 1, &descriptorSetCopy, dynamicOffsetCount: 0, dynamicOffsets: null);
136136
}
137137
SetRenderTargetsImpl(depthStencilBuffer, renderTargetCount, renderTargets);
138138
}
@@ -249,7 +249,11 @@ private unsafe void PrepareDraw()
249249

250250
// Lazily set the render pass and frame buffer
251251
EnsureRenderPass();
252+
BindDescriptorSets();
253+
}
252254

255+
private unsafe void BindDescriptorSets()
256+
{
253257
// Keep track of descriptor pool usage
254258
bool isPoolExhausted = ++allocatedSetCount > GraphicsDevice.MaxDescriptorSetCount;
255259
for (int i = 0; i < DescriptorSetLayout.DescriptorTypeCount; i++)
@@ -328,18 +332,28 @@ private unsafe void PrepareDraw()
328332
sType = VkStructureType.WriteDescriptorSet,
329333
descriptorType = mapping.DescriptorType,
330334
dstSet = localDescriptorSet,
331-
dstBinding = (uint) mapping.DestinationBinding,
335+
dstBinding = (uint)mapping.DestinationBinding,
332336
dstArrayElement = 0,
333337
descriptorCount = 1
334338
};
335339

336340
switch (mapping.DescriptorType)
337341
{
338342
case VkDescriptorType.SampledImage:
339-
var texture = heapObject.Value as Texture;
340-
descriptorData->ImageInfo = new VkDescriptorImageInfo { imageView = texture?.NativeImageView ?? GraphicsDevice.EmptyTexture.NativeImageView, imageLayout = VkImageLayout.ShaderReadOnlyOptimal };
341-
write->pImageInfo = &descriptorData->ImageInfo;
342-
break;
343+
{
344+
var texture = heapObject.Value as Texture;
345+
descriptorData->ImageInfo = new VkDescriptorImageInfo { imageView = texture?.NativeImageView ?? GraphicsDevice.EmptyTexture.NativeImageView, imageLayout = VkImageLayout.ShaderReadOnlyOptimal };
346+
write->pImageInfo = &descriptorData->ImageInfo;
347+
break;
348+
}
349+
350+
case VkDescriptorType.StorageImage:
351+
{
352+
var texture = heapObject.Value as Texture;
353+
descriptorData->ImageInfo = new VkDescriptorImageInfo { imageView = texture?.NativeImageView ?? GraphicsDevice.EmptyTexture.NativeImageView, imageLayout = VkImageLayout.General };
354+
write->pImageInfo = &descriptorData->ImageInfo;
355+
break;
356+
}
343357

344358
case VkDescriptorType.Sampler:
345359
var samplerState = heapObject.Value as SamplerState;
@@ -349,7 +363,7 @@ private unsafe void PrepareDraw()
349363

350364
case VkDescriptorType.UniformBuffer:
351365
var buffer = heapObject.Value as Buffer;
352-
descriptorData->BufferInfo = new VkDescriptorBufferInfo { buffer = buffer?.NativeBuffer ?? VkBuffer.Null, offset = (ulong) heapObject.Offset, range = (ulong) heapObject.Size };
366+
descriptorData->BufferInfo = new VkDescriptorBufferInfo { buffer = buffer?.NativeBuffer ?? VkBuffer.Null, offset = (ulong)heapObject.Offset, range = (ulong)heapObject.Size };
353367
write->pBufferInfo = &descriptorData->BufferInfo;
354368
break;
355369

@@ -359,14 +373,20 @@ private unsafe void PrepareDraw()
359373
write->pTexelBufferView = &descriptorData->BufferView;
360374
break;
361375

376+
case VkDescriptorType.StorageBuffer:
377+
buffer = heapObject.Value as Buffer;
378+
descriptorData->BufferInfo = new VkDescriptorBufferInfo { buffer = buffer?.NativeBuffer ?? VkBuffer.Null, offset = (ulong)heapObject.Offset, range = (ulong)(buffer?.SizeInBytes ?? 0)};
379+
write->pBufferInfo = &descriptorData->BufferInfo;
380+
break;
381+
362382
default:
363383
throw new InvalidOperationException();
364384
}
365385
}
366386

367-
vkUpdateDescriptorSets(GraphicsDevice.NativeDevice, (uint) bindingCount, writes, descriptorCopyCount: 0, descriptorCopies: null);
387+
vkUpdateDescriptorSets(GraphicsDevice.NativeDevice, (uint)bindingCount, writes, descriptorCopyCount: 0, descriptorCopies: null);
368388
#endif
369-
vkCmdBindDescriptorSets(currentCommandList.NativeCommandBuffer, VkPipelineBindPoint.Graphics, activePipeline.NativeLayout, firstSet: 0, descriptorSetCount: 1, &localDescriptorSet, dynamicOffsetCount: 0, dynamicOffsets: null);
389+
vkCmdBindDescriptorSets(currentCommandList.NativeCommandBuffer, activePipeline.IsCompute ? VkPipelineBindPoint.Compute : VkPipelineBindPoint.Graphics, activePipeline.NativeLayout, firstSet: 0, descriptorSetCount: 1, &localDescriptorSet, dynamicOffsetCount: 0, dynamicOffsets: null);
370390
}
371391

372392
private readonly FastList<VkCopyDescriptorSet> copies = new();
@@ -390,7 +410,7 @@ public void SetPipelineState(PipelineState pipelineState)
390410

391411
activePipeline = pipelineState;
392412

393-
vkCmdBindPipeline(currentCommandList.NativeCommandBuffer, VkPipelineBindPoint.Graphics, pipelineState.NativePipeline);
413+
vkCmdBindPipeline(currentCommandList.NativeCommandBuffer, activePipeline.IsCompute ? VkPipelineBindPoint.Compute : VkPipelineBindPoint.Graphics, pipelineState.NativePipeline);
394414
}
395415

396416
public unsafe void SetVertexBuffer(int index, Buffer buffer, int offset, int stride)
@@ -446,7 +466,7 @@ public unsafe void ResourceBarrierTransition(GraphicsResource resource, Graphics
446466
case GraphicsResourceState.PixelShaderResource:
447467
texture.NativeLayout = VkImageLayout.ShaderReadOnlyOptimal;
448468
texture.NativeAccessMask = VkAccessFlags.ShaderRead;
449-
texture.NativePipelineStageMask = VkPipelineStageFlags.FragmentShader;
469+
texture.NativePipelineStageMask = VkPipelineStageFlags.FragmentShader | VkPipelineStageFlags.ComputeShader;
450470
break;
451471
case GraphicsResourceState.GenericRead:
452472
texture.NativeLayout = VkImageLayout.General;
@@ -503,6 +523,9 @@ public void SetDescriptorSets(int index, DescriptorSet[] descriptorSets)
503523
/// <inheritdoc />
504524
public void Dispatch(int threadCountX, int threadCountY, int threadCountZ)
505525
{
526+
CleanupRenderPass();
527+
BindDescriptorSets();
528+
vkCmdDispatch(currentCommandList.NativeCommandBuffer, (uint)threadCountX, (uint)threadCountY, (uint)threadCountZ);
506529
}
507530

508531
/// <summary>
@@ -512,6 +535,9 @@ public void Dispatch(int threadCountX, int threadCountY, int threadCountZ)
512535
/// <param name="offsetInBytes">The offset information bytes.</param>
513536
public void Dispatch(Buffer indirectBuffer, int offsetInBytes)
514537
{
538+
CleanupRenderPass();
539+
BindDescriptorSets();
540+
vkCmdDispatchIndirect(currentCommandList.NativeCommandBuffer, indirectBuffer.NativeBuffer, (ulong)offsetInBytes);
515541
}
516542

517543
/// <summary>
@@ -1301,11 +1327,6 @@ public unsafe MappedResource MapSubresource(GraphicsResource resource, int subRe
13011327
throw new InvalidOperationException();
13021328
}
13031329

1304-
if (mapMode == MapMode.WriteDiscard)
1305-
{
1306-
throw new InvalidOperationException("Can't use WriteDiscard on Graphics API that doesn't support renaming");
1307-
}
1308-
13091330
if (mapMode != MapMode.WriteNoOverwrite && mapMode != MapMode.Write)
13101331
{
13111332
// Need to wait?

sources/engine/Stride.Graphics/Vulkan/GraphicsDevice.Vulkan.cs

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,20 @@ public partial class GraphicsDevice
4848

4949
internal HeapPool DescriptorPools;
5050
internal const uint MaxDescriptorSetCount = 256;
51-
internal readonly uint[] MaxDescriptorTypeCounts = new uint[DescriptorSetLayout.DescriptorTypeCount]
52-
{
51+
internal readonly uint[] MaxDescriptorTypeCounts =
52+
[
5353
256, // Sampler
5454
0, // CombinedImageSampler
5555
512, // SampledImage
56-
0, // StorageImage
56+
64, // StorageImage
5757
64, // UniformTexelBuffer
58-
0, // StorageTexelBuffer
58+
64, // StorageTexelBuffer
5959
512, // UniformBuffer
60-
0, // StorageBuffer
60+
64, // StorageBuffer
6161
0, // UniformBufferDynamic
6262
0, // StorageBufferDynamic
6363
0 // InputAttachment
64-
};
64+
];
6565

6666
internal Buffer EmptyTexelBufferInt, EmptyTexelBufferFloat;
6767
internal Texture EmptyTexture;
@@ -264,6 +264,22 @@ private unsafe void InitializePlatformDevice(GraphicsProfile[] graphicsProfiles,
264264
ConstantBufferDataPlacementAlignment = (int)physicalDeviceProperties.limits.minUniformBufferOffsetAlignment;
265265
TimestampFrequency = (long)(1.0e9 / physicalDeviceProperties.limits.timestampPeriod); // Resolution in nanoseconds
266266

267+
// Configure descriptor type max counts
268+
void SetMaxDescriptorTypeCount(VkDescriptorType type, uint limit)
269+
=> MaxDescriptorTypeCounts[(int)type] = Math.Min(MaxDescriptorTypeCounts[(int)type], limit);
270+
271+
SetMaxDescriptorTypeCount(VkDescriptorType.Sampler, physicalDeviceProperties.limits.maxDescriptorSetSamplers);
272+
SetMaxDescriptorTypeCount(VkDescriptorType.CombinedImageSampler, 0); // Not defined.
273+
SetMaxDescriptorTypeCount(VkDescriptorType.SampledImage, physicalDeviceProperties.limits.maxDescriptorSetSampledImages);
274+
SetMaxDescriptorTypeCount(VkDescriptorType.StorageImage, physicalDeviceProperties.limits.maxDescriptorSetStorageImages);
275+
SetMaxDescriptorTypeCount(VkDescriptorType.UniformTexelBuffer, physicalDeviceProperties.limits.maxDescriptorSetSampledImages); // No individual limit
276+
SetMaxDescriptorTypeCount(VkDescriptorType.StorageTexelBuffer, physicalDeviceProperties.limits.maxDescriptorSetStorageImages); // No individual limit
277+
SetMaxDescriptorTypeCount(VkDescriptorType.UniformBuffer, physicalDeviceProperties.limits.maxDescriptorSetUniformBuffers);
278+
SetMaxDescriptorTypeCount(VkDescriptorType.StorageBuffer, physicalDeviceProperties.limits.maxDescriptorSetStorageBuffers);
279+
SetMaxDescriptorTypeCount(VkDescriptorType.UniformBufferDynamic, physicalDeviceProperties.limits.maxDescriptorSetUniformBuffersDynamic);
280+
SetMaxDescriptorTypeCount(VkDescriptorType.StorageBufferDynamic, physicalDeviceProperties.limits.maxDescriptorSetStorageBuffersDynamic);
281+
SetMaxDescriptorTypeCount(VkDescriptorType.InputAttachment, physicalDeviceProperties.limits.maxDescriptorSetInputAttachments);
282+
267283
RequestedProfile = graphicsProfiles.First();
268284

269285
var queueProperties = vkGetPhysicalDeviceQueueFamilyProperties(NativePhysicalDevice);
@@ -292,11 +308,24 @@ private unsafe void InitializePlatformDevice(GraphicsProfile[] graphicsProfiles,
292308
depthClamp = true,
293309
};
294310

311+
vkGetPhysicalDeviceFeatures(NativePhysicalDevice, out var deviceFeatures);
312+
313+
if (deviceFeatures.shaderStorageImageReadWithoutFormat)
314+
{
315+
enabledFeature.shaderStorageImageReadWithoutFormat = true;
316+
}
317+
318+
if (deviceFeatures.shaderStorageImageWriteWithoutFormat)
319+
{
320+
enabledFeature.shaderStorageImageWriteWithoutFormat = true;
321+
}
322+
295323
Span<VkUtf8String> supportedExtensionProperties = stackalloc VkUtf8String[]
296324
{
297325
VK_KHR_SWAPCHAIN_EXTENSION_NAME,
298326
VK_EXT_DEBUG_MARKER_EXTENSION_NAME,
299327
};
328+
300329
var availableExtensionProperties = GetAvailableExtensionProperties(supportedExtensionProperties);
301330
ValidateExtensionPropertiesAvailability(availableExtensionProperties);
302331
var desiredExtensionProperties = new HashSet<VkUtf8String>

0 commit comments

Comments
 (0)