Skip to content

Commit 1affeef

Browse files
authored
Add tessellation support to DX/VK backends with accompanying SimpleTriangleTess.test (#1224)
Adds `Hull` and `Domain` shader support, plus a `PatchList` primitive topology with per-call control-point count. Wires the new stages on the DX12 and Vulkan backends. Metal explicitly rejects them. Metal's tessellation model (HS as a compute kernel, DS as a `[[patch(...)]]`-tagged vertex function) doesn't neatly fit the HS/DS shape, so the path is intentionally left unimplemented. Adds one end-to-end test, `SimpleTriangleTess.test`, exercising a VS → HS → DS → PS pipeline that renders the `SimpleTriangle` image test via tessellation.
1 parent 997b1c7 commit 1affeef

8 files changed

Lines changed: 329 additions & 8 deletions

File tree

include/API/Device.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,15 @@ struct TraditionalRasterPipelineCreateDesc {
8787
llvm::SmallVector<Format> RTFormats;
8888
std::optional<Format> DSFormat;
8989
PrimitiveTopology Topology;
90+
// Set if Topology == PatchList. Validated in
91+
// Pipeline.cpp::validatePipelineKind.
92+
std::optional<uint32_t> PatchControlPoints;
93+
9094
ShaderContainer VS;
91-
// TODO: Optional Hull & Domain Shaders
95+
// Hull and Domain are independent optionals here; Pipeline.cpp enforces that
96+
// they must be set as a pair (and only with PatchList topology).
97+
std::optional<ShaderContainer> HS;
98+
std::optional<ShaderContainer> DS;
9299
std::optional<ShaderContainer> GS;
93100
ShaderContainer PS;
94101

@@ -97,6 +104,12 @@ struct TraditionalRasterPipelineCreateDesc {
97104
case Stages::Vertex:
98105
VS = std::move(SC);
99106
break;
107+
case Stages::Hull:
108+
HS = std::move(SC);
109+
break;
110+
case Stages::Domain:
111+
DS = std::move(SC);
112+
break;
100113
case Stages::Geometry:
101114
GS = std::move(SC);
102115
break;

include/API/Enums.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ enum class StoreAction {
4444
DontCare, ///< Contents may be discarded after the pass.
4545
};
4646

47-
enum class PrimitiveTopology { TriangleList, PointList };
47+
enum class PrimitiveTopology { TriangleList, PointList, PatchList };
4848

4949
} // namespace offloadtest
5050

include/Support/Pipeline.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "llvm/Support/YAMLTraits.h"
2222
#include <limits>
2323
#include <memory>
24+
#include <optional>
2425
#include <string>
2526
#include <variant>
2627

@@ -32,6 +33,8 @@ enum class Stages {
3233

3334
// Traditional Raster
3435
Vertex,
36+
Hull,
37+
Domain,
3538
Geometry,
3639
Pixel,
3740

@@ -40,8 +43,8 @@ enum class Stages {
4043
Mesh
4144
};
4245
inline constexpr std::array AllStages = {
43-
Stages::Compute, Stages::Vertex, Stages::Geometry,
44-
Stages::Pixel, Stages::Amplification, Stages::Mesh,
46+
Stages::Compute, Stages::Vertex, Stages::Hull, Stages::Domain,
47+
Stages::Geometry, Stages::Pixel, Stages::Amplification, Stages::Mesh,
4548
};
4649
inline constexpr size_t NumStages = AllStages.size();
4750

@@ -402,6 +405,12 @@ struct IOBindings {
402405
CPUBuffer *RTargetBufferPtr = nullptr;
403406
PrimitiveTopology Topology = PrimitiveTopology::TriangleList;
404407

408+
// Set if Topology == PatchList. Validated in
409+
// Pipeline.cpp::validatePipelineKind. Valid range is 1..32 (matches both
410+
// D3D12's per-CP-patchlist topologies and Vulkan's
411+
// VkPipelineTessellationStateCreateInfo::patchControlPoints).
412+
std::optional<uint32_t> PatchControlPoints;
413+
405414
uint32_t getVertexStride() const {
406415
uint32_t Stride = 0;
407416
for (auto VA : VertexAttributes)
@@ -730,6 +739,8 @@ template <> struct ScalarEnumerationTraits<offloadtest::Stages> {
730739
#define ENUM_CASE(Val) I.enumCase(V, #Val, offloadtest::Stages::Val)
731740
ENUM_CASE(Compute);
732741
ENUM_CASE(Vertex);
742+
ENUM_CASE(Hull);
743+
ENUM_CASE(Domain);
733744
ENUM_CASE(Geometry);
734745
ENUM_CASE(Pixel);
735746
ENUM_CASE(Amplification);
@@ -743,6 +754,7 @@ template <> struct ScalarEnumerationTraits<offloadtest::PrimitiveTopology> {
743754
#define ENUM_CASE(Val) I.enumCase(V, #Val, offloadtest::PrimitiveTopology::Val)
744755
ENUM_CASE(TriangleList);
745756
ENUM_CASE(PointList);
757+
ENUM_CASE(PatchList);
746758
#undef ENUM_CASE
747759
}
748760
};

lib/API/DX/Device.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,17 +165,27 @@ getDXPrimitiveTopologyType(PrimitiveTopology Topology) {
165165
return D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE;
166166
case PrimitiveTopology::PointList:
167167
return D3D12_PRIMITIVE_TOPOLOGY_TYPE_POINT;
168+
case PrimitiveTopology::PatchList:
169+
return D3D12_PRIMITIVE_TOPOLOGY_TYPE_PATCH;
168170
}
169171
llvm_unreachable("All PrimitiveTopology cases handled");
170172
}
171173

172174
static D3D_PRIMITIVE_TOPOLOGY
173-
getDXPrimitiveTopology(PrimitiveTopology Topology) {
175+
getDXPrimitiveTopology(PrimitiveTopology Topology,
176+
std::optional<uint32_t> PatchControlPoints) {
174177
switch (Topology) {
175178
case PrimitiveTopology::TriangleList:
176179
return D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST;
177180
case PrimitiveTopology::PointList:
178181
return D3D_PRIMITIVE_TOPOLOGY_POINTLIST;
182+
case PrimitiveTopology::PatchList:
183+
// _N_CONTROL_POINT_PATCHLIST enums are contiguous from 1..32.
184+
assert(PatchControlPoints && *PatchControlPoints >= 1 &&
185+
*PatchControlPoints <= 32);
186+
return static_cast<D3D_PRIMITIVE_TOPOLOGY>(
187+
D3D_PRIMITIVE_TOPOLOGY_1_CONTROL_POINT_PATCHLIST +
188+
(*PatchControlPoints - 1));
179189
}
180190
llvm_unreachable("All PrimitiveTopology cases handled");
181191
}
@@ -1178,6 +1188,12 @@ class DXDevice : public offloadtest::Device {
11781188
return llvm::createStringError(std::errc::invalid_argument,
11791189
"Graphics pipeline requires both a vertex "
11801190
"shader and a pixel shader.");
1191+
if (Desc.HS)
1192+
PSODesc.HS = {Desc.HS->Shader->getBuffer().data(),
1193+
Desc.HS->Shader->getBuffer().size()};
1194+
if (Desc.DS)
1195+
PSODesc.DS = {Desc.DS->Shader->getBuffer().data(),
1196+
Desc.DS->Shader->getBuffer().size()};
11811197
if (Desc.GS)
11821198
PSODesc.GS = {Desc.GS->Shader->getBuffer().data(),
11831199
Desc.GS->Shader->getBuffer().size()};
@@ -1206,7 +1222,8 @@ class DXDevice : public offloadtest::Device {
12061222
return Err;
12071223

12081224
return std::make_unique<DXPipelineState>(
1209-
Name, RootSig, PSO, getDXPrimitiveTopology(Desc.Topology));
1225+
Name, RootSig, PSO,
1226+
getDXPrimitiveTopology(Desc.Topology, Desc.PatchControlPoints));
12101227
}
12111228

12121229
llvm::Expected<std::unique_ptr<PipelineState>>
@@ -2511,6 +2528,7 @@ class DXDevice : public offloadtest::Device {
25112528

25122529
TraditionalRasterPipelineCreateDesc PipelineDesc = {};
25132530
PipelineDesc.Topology = P.Bindings.Topology;
2531+
PipelineDesc.PatchControlPoints = P.Bindings.PatchControlPoints;
25142532
PipelineDesc.DSFormat = Format::D32FloatS8Uint;
25152533
for (auto &Shader : P.Shaders) {
25162534
ShaderContainer SC = {};

lib/API/MTL/MTLDevice.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ static IRShaderStage getShaderStage(Stages Stage) {
133133
return IRShaderStageCompute;
134134
case Stages::Vertex:
135135
return IRShaderStageVertex;
136+
case Stages::Hull:
137+
llvm_unreachable("Hull shaders are not supported on Metal.");
138+
case Stages::Domain:
139+
llvm_unreachable("Domain shaders are not supported on Metal.");
136140
case Stages::Geometry:
137141
llvm_unreachable("Geometry shaders are not supported on Metal.");
138142
case Stages::Pixel:
@@ -1584,6 +1588,11 @@ class MTLDevice : public offloadtest::Device {
15841588
return llvm::createStringError(
15851589
std::errc::not_supported,
15861590
"Geometry shaders are not supported on this backend.");
1591+
if (Desc.HS || Desc.DS)
1592+
return llvm::createStringError(
1593+
std::errc::not_supported,
1594+
"Hull/Domain (tessellation) shaders are not supported on this "
1595+
"backend.");
15871596
if (Desc.Topology != PrimitiveTopology::TriangleList)
15881597
return llvm::createStringError(
15891598
std::errc::not_supported,

lib/API/VK/Device.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,10 @@ static VkShaderStageFlagBits getShaderStageFlag(Stages Stage) {
206206
return VK_SHADER_STAGE_COMPUTE_BIT;
207207
case Stages::Vertex:
208208
return VK_SHADER_STAGE_VERTEX_BIT;
209+
case Stages::Hull:
210+
return VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT;
211+
case Stages::Domain:
212+
return VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
209213
case Stages::Geometry:
210214
return VK_SHADER_STAGE_GEOMETRY_BIT;
211215
case Stages::Pixel:
@@ -224,6 +228,8 @@ static VkPrimitiveTopology getVkPrimitiveTopology(PrimitiveTopology Topology) {
224228
return VK_PRIMITIVE_TOPOLOGY_TRIANGLE_LIST;
225229
case PrimitiveTopology::PointList:
226230
return VK_PRIMITIVE_TOPOLOGY_POINT_LIST;
231+
case PrimitiveTopology::PatchList:
232+
return VK_PRIMITIVE_TOPOLOGY_PATCH_LIST;
227233
}
228234
llvm_unreachable("All PrimitiveTopology cases handled");
229235
}
@@ -1469,6 +1475,8 @@ class VulkanDevice : public offloadtest::Device {
14691475
const TraditionalRasterPipelineCreateDesc &Desc) override {
14701476
const ShaderContainer &VS = Desc.VS;
14711477
const ShaderContainer &PS = Desc.PS;
1478+
const std::optional<ShaderContainer> &HS = Desc.HS;
1479+
const std::optional<ShaderContainer> &DS = Desc.DS;
14721480
const std::optional<ShaderContainer> &GS = Desc.GS;
14731481
const llvm::ArrayRef<InputLayoutDesc> InputLayout = Desc.InputLayout;
14741482
const llvm::ArrayRef<Format> RTFormats = Desc.RTFormats;
@@ -1505,6 +1513,56 @@ class VulkanDevice : public offloadtest::Device {
15051513
ShaderStages.push_back(ShaderStage);
15061514
}
15071515

1516+
llvm::SmallVector<VkSpecializationMapEntry> HSSpecEntries;
1517+
llvm::SmallVector<char> HSSpecData;
1518+
VkSpecializationInfo HSSpecInfo = {};
1519+
if (HS) {
1520+
if (auto Err = parseSpecializationConstants(HS->SpecializationConstants,
1521+
HSSpecEntries, HSSpecData,
1522+
HSSpecInfo))
1523+
return Err;
1524+
1525+
auto HSModOrErr = createShaderModule(HS->Shader, "hull");
1526+
if (!HSModOrErr)
1527+
return HSModOrErr.takeError();
1528+
1529+
GraphicsFlags |= VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT;
1530+
1531+
VkPipelineShaderStageCreateInfo ShaderStage = {};
1532+
ShaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
1533+
ShaderStage.stage = VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT;
1534+
ShaderStage.module = *HSModOrErr;
1535+
ShaderStage.pName = HS->EntryPoint.c_str();
1536+
ShaderStage.pSpecializationInfo =
1537+
HS->SpecializationConstants.empty() ? nullptr : &HSSpecInfo;
1538+
ShaderStages.push_back(ShaderStage);
1539+
}
1540+
1541+
llvm::SmallVector<VkSpecializationMapEntry> DSSpecEntries;
1542+
llvm::SmallVector<char> DSSpecData;
1543+
VkSpecializationInfo DSSpecInfo = {};
1544+
if (DS) {
1545+
if (auto Err = parseSpecializationConstants(DS->SpecializationConstants,
1546+
DSSpecEntries, DSSpecData,
1547+
DSSpecInfo))
1548+
return Err;
1549+
1550+
auto DSModOrErr = createShaderModule(DS->Shader, "domain");
1551+
if (!DSModOrErr)
1552+
return DSModOrErr.takeError();
1553+
1554+
GraphicsFlags |= VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
1555+
1556+
VkPipelineShaderStageCreateInfo ShaderStage = {};
1557+
ShaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
1558+
ShaderStage.stage = VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
1559+
ShaderStage.module = *DSModOrErr;
1560+
ShaderStage.pName = DS->EntryPoint.c_str();
1561+
ShaderStage.pSpecializationInfo =
1562+
DS->SpecializationConstants.empty() ? nullptr : &DSSpecInfo;
1563+
ShaderStages.push_back(ShaderStage);
1564+
}
1565+
15081566
llvm::SmallVector<VkSpecializationMapEntry> GSSpecEntries;
15091567
llvm::SmallVector<char> GSSpecData;
15101568
VkSpecializationInfo GSSpecInfo = {};
@@ -1633,6 +1691,13 @@ class VulkanDevice : public offloadtest::Device {
16331691
VK_STRUCTURE_TYPE_PIPELINE_INPUT_ASSEMBLY_STATE_CREATE_INFO;
16341692
InputAssemblyCI.topology = getVkPrimitiveTopology(Desc.Topology);
16351693

1694+
VkPipelineTessellationStateCreateInfo TessellationCI = {};
1695+
if (Desc.PatchControlPoints) {
1696+
TessellationCI.sType =
1697+
VK_STRUCTURE_TYPE_PIPELINE_TESSELLATION_STATE_CREATE_INFO;
1698+
TessellationCI.patchControlPoints = *Desc.PatchControlPoints;
1699+
}
1700+
16361701
VkPipelineViewportStateCreateInfo ViewportCI = {};
16371702
ViewportCI.sType = VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO;
16381703
ViewportCI.viewportCount = 1;
@@ -1683,6 +1748,8 @@ class VulkanDevice : public offloadtest::Device {
16831748
PipelineCI.pStages = ShaderStages.data();
16841749
PipelineCI.pVertexInputState = &VertexInputCI;
16851750
PipelineCI.pInputAssemblyState = &InputAssemblyCI;
1751+
PipelineCI.pTessellationState =
1752+
Desc.PatchControlPoints ? &TessellationCI : nullptr;
16861753
PipelineCI.pViewportState = &ViewportCI;
16871754
PipelineCI.pRasterizationState = &RastCI;
16881755
PipelineCI.pMultisampleState = &MultisampleCI;
@@ -3229,6 +3296,7 @@ class VulkanDevice : public offloadtest::Device {
32293296
} else if (P.isTraditionalRaster()) {
32303297
TraditionalRasterPipelineCreateDesc PipelineDesc = {};
32313298
PipelineDesc.Topology = P.Bindings.Topology;
3299+
PipelineDesc.PatchControlPoints = P.Bindings.PatchControlPoints;
32323300
PipelineDesc.DSFormat = Format::D32FloatS8Uint;
32333301
for (auto &Shader : P.Shaders) {
32343302
ShaderContainer SC = {};

lib/Support/Pipeline.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@ namespace yaml {
5555
void MappingTraits<offloadtest::Pipeline>::mapping(IO &I,
5656
offloadtest::Pipeline &P) {
5757
I.mapRequired("Shaders", P.Shaders);
58-
if (auto Err = P.validatePipelineKind())
59-
I.setError(llvm::toString(std::move(Err)));
6058

6159
// Runtime-specific settings.
6260
I.mapOptional("RuntimeSettings", P.Settings);
@@ -68,6 +66,12 @@ void MappingTraits<offloadtest::Pipeline>::mapping(IO &I,
6866
I.mapOptional("Bindings", P.Bindings);
6967
I.mapOptional("PushConstants", P.PushConstants);
7068

69+
// Runs here (not right after Shaders) because the tessellation topology
70+
// check reads Bindings.Topology and Bindings.PatchControlPoints. Must
71+
// still run before validateDispatchParameters, which reads P.Kind.
72+
if (auto Err = P.validatePipelineKind())
73+
I.setError(llvm::toString(std::move(Err)));
74+
7175
I.mapOptional("DispatchParameters", P.DispatchParameters);
7276
if (auto Err = P.validateDispatchParameters())
7377
I.setError(llvm::toString(std::move(Err)));
@@ -427,6 +431,7 @@ void MappingTraits<offloadtest::IOBindings>::mapping(
427431
I.mapOptional("RenderTarget", B.RenderTarget);
428432
I.mapOptional("Topology", B.Topology,
429433
offloadtest::PrimitiveTopology::TriangleList);
434+
I.mapOptional("PatchControlPoints", B.PatchControlPoints);
430435
}
431436

432437
void MappingTraits<offloadtest::PushConstantBlock>::mapping(
@@ -605,6 +610,26 @@ llvm::Error offloadtest::Pipeline::validatePipelineKind() {
605610
return llvm::createStringError("Vertex and Mesh/Amplification Shaders "
606611
"cannot be used in the same pipeline.");
607612

613+
const bool HasHS = HasShaderType[llvm::to_underlying(Stages::Hull)];
614+
const bool HasDS = HasShaderType[llvm::to_underlying(Stages::Domain)];
615+
if (HasHS != HasDS)
616+
return llvm::createStringError(
617+
"Hull and Domain shaders must be used together");
618+
619+
const bool IsTessellated = HasHS && HasDS;
620+
const bool IsPatchList = Bindings.Topology == PrimitiveTopology::PatchList;
621+
if (IsTessellated != IsPatchList)
622+
return llvm::createStringError(
623+
"Tessellation pipelines must use PatchList topology");
624+
if (IsPatchList &&
625+
(!Bindings.PatchControlPoints || *Bindings.PatchControlPoints < 1 ||
626+
*Bindings.PatchControlPoints > 32))
627+
return llvm::createStringError(
628+
"PatchList topology requires PatchControlPoints in the range 1..32.");
629+
if (!IsPatchList && Bindings.PatchControlPoints)
630+
return llvm::createStringError(
631+
"PatchControlPoints is only valid with PatchList topology.");
632+
608633
Kind = ShaderPipelineKind::TraditionalRaster;
609634
return llvm::Error::success();
610635
}

0 commit comments

Comments
 (0)