Skip to content

Commit 1355baa

Browse files
committed
feat: Added compute shader support for vulkan
Vulkan graphics backend has been modified to support compute shaders, additional modifications were also made to the shader compiler so that correct glsl compute shaders can be generated.
1 parent 665a05e commit 1355baa

12 files changed

Lines changed: 452 additions & 224 deletions

File tree

sources/engine/Stride.Graphics/Vulkan/CommandList.Vulkan.cs

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ private unsafe void FlushInternal(bool wait)
130130

131131
if (activePipeline != null)
132132
{
133-
vkCmdBindPipeline(currentCommandList.NativeCommandBuffer, VkPipelineBindPoint.Graphics, activePipeline.NativePipeline);
133+
vkCmdBindPipeline(currentCommandList.NativeCommandBuffer, activePipeline.IsCompute ? VkPipelineBindPoint.Compute : VkPipelineBindPoint.Graphics, activePipeline.NativePipeline);
134134
var descriptorSetCopy = descriptorSet;
135-
vkCmdBindDescriptorSets(currentCommandList.NativeCommandBuffer, VkPipelineBindPoint.Graphics, activePipeline.NativeLayout, firstSet: 0, descriptorSetCount: 1, &descriptorSetCopy, dynamicOffsetCount: 0, dynamicOffsets: null);
135+
vkCmdBindDescriptorSets(currentCommandList.NativeCommandBuffer, activePipeline.IsCompute ? VkPipelineBindPoint.Compute : VkPipelineBindPoint.Graphics, activePipeline.NativeLayout, firstSet: 0, descriptorSetCount: 1, &descriptorSetCopy, dynamicOffsetCount: 0, dynamicOffsets: null);
136136
}
137137
SetRenderTargetsImpl(depthStencilBuffer, renderTargetCount, renderTargets);
138138
}
@@ -249,7 +249,11 @@ private unsafe void PrepareDraw()
249249

250250
// Lazily set the render pass and frame buffer
251251
EnsureRenderPass();
252+
BindDescriptorSets();
253+
}
252254

255+
private unsafe void BindDescriptorSets()
256+
{
253257
// Keep track of descriptor pool usage
254258
bool isPoolExhausted = ++allocatedSetCount > GraphicsDevice.MaxDescriptorSetCount;
255259
for (int i = 0; i < DescriptorSetLayout.DescriptorTypeCount; i++)
@@ -328,18 +332,28 @@ private unsafe void PrepareDraw()
328332
sType = VkStructureType.WriteDescriptorSet,
329333
descriptorType = mapping.DescriptorType,
330334
dstSet = localDescriptorSet,
331-
dstBinding = (uint) mapping.DestinationBinding,
335+
dstBinding = (uint)mapping.DestinationBinding,
332336
dstArrayElement = 0,
333337
descriptorCount = 1
334338
};
335339

336340
switch (mapping.DescriptorType)
337341
{
338342
case VkDescriptorType.SampledImage:
339-
var texture = heapObject.Value as Texture;
340-
descriptorData->ImageInfo = new VkDescriptorImageInfo { imageView = texture?.NativeImageView ?? GraphicsDevice.EmptyTexture.NativeImageView, imageLayout = VkImageLayout.ShaderReadOnlyOptimal };
341-
write->pImageInfo = &descriptorData->ImageInfo;
342-
break;
343+
{
344+
var texture = heapObject.Value as Texture;
345+
descriptorData->ImageInfo = new VkDescriptorImageInfo { imageView = texture?.NativeImageView ?? GraphicsDevice.EmptyTexture.NativeImageView, imageLayout = VkImageLayout.ShaderReadOnlyOptimal };
346+
write->pImageInfo = &descriptorData->ImageInfo;
347+
break;
348+
}
349+
350+
case VkDescriptorType.StorageImage:
351+
{
352+
var texture = heapObject.Value as Texture;
353+
descriptorData->ImageInfo = new VkDescriptorImageInfo { imageView = texture?.NativeImageView ?? GraphicsDevice.EmptyTexture.NativeImageView, imageLayout = VkImageLayout.General };
354+
write->pImageInfo = &descriptorData->ImageInfo;
355+
break;
356+
}
343357

344358
case VkDescriptorType.Sampler:
345359
var samplerState = heapObject.Value as SamplerState;
@@ -349,7 +363,7 @@ private unsafe void PrepareDraw()
349363

350364
case VkDescriptorType.UniformBuffer:
351365
var buffer = heapObject.Value as Buffer;
352-
descriptorData->BufferInfo = new VkDescriptorBufferInfo { buffer = buffer?.NativeBuffer ?? VkBuffer.Null, offset = (ulong) heapObject.Offset, range = (ulong) heapObject.Size };
366+
descriptorData->BufferInfo = new VkDescriptorBufferInfo { buffer = buffer?.NativeBuffer ?? VkBuffer.Null, offset = (ulong)heapObject.Offset, range = (ulong)heapObject.Size };
353367
write->pBufferInfo = &descriptorData->BufferInfo;
354368
break;
355369

@@ -364,9 +378,9 @@ private unsafe void PrepareDraw()
364378
}
365379
}
366380

367-
vkUpdateDescriptorSets(GraphicsDevice.NativeDevice, (uint) bindingCount, writes, descriptorCopyCount: 0, descriptorCopies: null);
381+
vkUpdateDescriptorSets(GraphicsDevice.NativeDevice, (uint)bindingCount, writes, descriptorCopyCount: 0, descriptorCopies: null);
368382
#endif
369-
vkCmdBindDescriptorSets(currentCommandList.NativeCommandBuffer, VkPipelineBindPoint.Graphics, activePipeline.NativeLayout, firstSet: 0, descriptorSetCount: 1, &localDescriptorSet, dynamicOffsetCount: 0, dynamicOffsets: null);
383+
vkCmdBindDescriptorSets(currentCommandList.NativeCommandBuffer, activePipeline.IsCompute ? VkPipelineBindPoint.Compute : VkPipelineBindPoint.Graphics, activePipeline.NativeLayout, firstSet: 0, descriptorSetCount: 1, &localDescriptorSet, dynamicOffsetCount: 0, dynamicOffsets: null);
370384
}
371385

372386
private readonly FastList<VkCopyDescriptorSet> copies = new();
@@ -390,7 +404,7 @@ public void SetPipelineState(PipelineState pipelineState)
390404

391405
activePipeline = pipelineState;
392406

393-
vkCmdBindPipeline(currentCommandList.NativeCommandBuffer, VkPipelineBindPoint.Graphics, pipelineState.NativePipeline);
407+
vkCmdBindPipeline(currentCommandList.NativeCommandBuffer, activePipeline.IsCompute ? VkPipelineBindPoint.Compute : VkPipelineBindPoint.Graphics, pipelineState.NativePipeline);
394408
}
395409

396410
public unsafe void SetVertexBuffer(int index, Buffer buffer, int offset, int stride)
@@ -446,7 +460,7 @@ public unsafe void ResourceBarrierTransition(GraphicsResource resource, Graphics
446460
case GraphicsResourceState.PixelShaderResource:
447461
texture.NativeLayout = VkImageLayout.ShaderReadOnlyOptimal;
448462
texture.NativeAccessMask = VkAccessFlags.ShaderRead;
449-
texture.NativePipelineStageMask = VkPipelineStageFlags.FragmentShader;
463+
texture.NativePipelineStageMask = VkPipelineStageFlags.FragmentShader | VkPipelineStageFlags.ComputeShader; // TODO: Not sure why I did this can probably double check ...
450464
break;
451465
case GraphicsResourceState.GenericRead:
452466
texture.NativeLayout = VkImageLayout.General;
@@ -503,6 +517,9 @@ public void SetDescriptorSets(int index, DescriptorSet[] descriptorSets)
503517
/// <inheritdoc />
504518
public void Dispatch(int threadCountX, int threadCountY, int threadCountZ)
505519
{
520+
CleanupRenderPass();
521+
BindDescriptorSets();
522+
vkCmdDispatch(currentCommandList.NativeCommandBuffer, (uint)threadCountX, (uint)threadCountY, (uint)threadCountZ);
506523
}
507524

508525
/// <summary>
@@ -512,6 +529,9 @@ public void Dispatch(int threadCountX, int threadCountY, int threadCountZ)
512529
/// <param name="offsetInBytes">The offset information bytes.</param>
513530
public void Dispatch(Buffer indirectBuffer, int offsetInBytes)
514531
{
532+
CleanupRenderPass();
533+
BindDescriptorSets();
534+
vkCmdDispatchIndirect(currentCommandList.NativeCommandBuffer, indirectBuffer.NativeBuffer, (ulong)offsetInBytes);
515535
}
516536

517537
/// <summary>

sources/engine/Stride.Graphics/Vulkan/GraphicsDevice.Vulkan.cs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ public partial class GraphicsDevice
5252
256, // Sampler
5353
0, // CombinedImageSampler
5454
512, // SampledImage
55-
0, // StorageImage
55+
32, // StorageImage
5656
64, // UniformTexelBuffer
57-
0, // StorageTexelBuffer
57+
32, // StorageTexelBuffer
5858
512, // UniformBuffer
5959
0, // StorageBuffer
6060
0, // UniformBufferDynamic
@@ -287,6 +287,12 @@ private unsafe void InitializePlatformDevice(GraphicsProfile[] graphicsProfiles,
287287
depthClamp = true,
288288
};
289289

290+
if (graphicsProfiles.Any(x => x >= GraphicsProfile.Level_11_0))
291+
{
292+
enabledFeature.shaderStorageImageReadWithoutFormat = true;
293+
enabledFeature.shaderStorageImageWriteWithoutFormat = true;
294+
}
295+
290296
var extensionProperties = vkEnumerateDeviceExtensionProperties(NativePhysicalDevice);
291297
var availableExtensionNames = new List<string>();
292298
var desiredExtensionNames = new List<string>();

0 commit comments

Comments
 (0)