Skip to content

Commit a55e3c4

Browse files
Tests: Add mesh shader API tests for Metal backend
1 parent 5d269a7 commit a55e3c4

File tree

6 files changed

+521
-3
lines changed

6 files changed

+521
-3
lines changed
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
/*
2+
* Copyright 2026 Diligent Graphics LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
* In no event and under no legal theory, whether in tort (including negligence),
17+
* contract, or otherwise, unless required by applicable law (such as deliberate
18+
* and grossly negligent acts) or agreed to in writing, shall any Contributor be
19+
* liable for any damages, including any direct, indirect, special, incidental,
20+
* or consequential damages of any character arising as a result of this License or
21+
* out of the use or inability to use the software (including but not limited to damages
22+
* for loss of goodwill, work stoppage, computer failure or malfunction, or any and
23+
* all other commercial damages or losses), even if such Contributor has been advised
24+
* of the possibility of such damages.
25+
*/
26+
27+
#include <string>
28+
29+
namespace
30+
{
31+
32+
namespace MSL
33+
{
34+
35+
// clang-format off
36+
const std::string MeshShaderTest{
37+
R"(
38+
#include <metal_stdlib>
39+
using namespace metal;
40+
41+
struct VertexOut
42+
{
43+
float4 position [[position]];
44+
float3 color;
45+
};
46+
47+
using TriMesh = metal::mesh<VertexOut, void, 4, 2, metal::topology::triangle>;
48+
49+
[[mesh]]
50+
void MSmain(uint tid [[thread_index_in_threadgroup]],
51+
TriMesh output)
52+
{
53+
if (tid == 0)
54+
output.set_primitive_count(2);
55+
56+
const float3 colors[4] = {float3(1.0, 0.0, 0.0), float3(0.0, 1.0, 0.0),
57+
float3(0.0, 0.0, 1.0), float3(1.0, 1.0, 1.0)};
58+
59+
VertexOut v;
60+
v.position = float4(float(tid >> 1) * 2.0 - 1.0, float(tid & 1) * 2.0 - 1.0, 0.0, 1.0);
61+
v.color = colors[tid];
62+
output.set_vertex(tid, v);
63+
64+
// Triangle 0: (0, 1, 2)
65+
if (tid == 0)
66+
{
67+
output.set_index(0, 0);
68+
output.set_index(1, 1);
69+
output.set_index(2, 2);
70+
}
71+
// Triangle 1: (2, 1, 3)
72+
if (tid == 3)
73+
{
74+
output.set_index(3, 2);
75+
output.set_index(4, 1);
76+
output.set_index(5, 3);
77+
}
78+
}
79+
80+
struct FSOut
81+
{
82+
float4 color [[color(0)]];
83+
};
84+
85+
fragment FSOut PSmain(VertexOut in [[stage_in]])
86+
{
87+
FSOut out;
88+
out.color = float4(in.color, 1.0);
89+
return out;
90+
}
91+
)"
92+
};
93+
94+
const std::string AmplificationShaderTest{
95+
R"(
96+
#include <metal_stdlib>
97+
using namespace metal;
98+
99+
struct VertexOut
100+
{
101+
float4 position [[position]];
102+
float3 color;
103+
};
104+
105+
struct Payload
106+
{
107+
uint baseID;
108+
uint subIDs[8];
109+
};
110+
111+
// Object (amplification) shader
112+
[[object]]
113+
void OBJmain(uint tid [[thread_index_in_threadgroup]],
114+
uint gid [[threadgroup_position_in_grid]],
115+
object_data Payload& payload [[payload]],
116+
mesh_grid_properties mgp)
117+
{
118+
if (tid == 0)
119+
payload.baseID = gid * 8;
120+
payload.subIDs[tid] = tid;
121+
122+
if (tid == 0)
123+
mgp.set_threadgroups_per_grid(uint3(8, 1, 1));
124+
}
125+
126+
using SmallTriMesh = metal::mesh<VertexOut, void, 3, 1, metal::topology::triangle>;
127+
128+
// Mesh shader for amplification test
129+
[[mesh]]
130+
void AmpMSmain(uint gid [[threadgroup_position_in_grid]],
131+
const object_data Payload& payload [[payload]],
132+
SmallTriMesh output)
133+
{
134+
output.set_primitive_count(1);
135+
136+
uint meshletID = payload.baseID + payload.subIDs[gid];
137+
138+
const float3 colors[4] = {float3(1.0, 0.0, 0.0), float3(0.0, 1.0, 0.0),
139+
float3(0.0, 0.0, 1.0), float3(1.0, 0.0, 1.0)};
140+
141+
float2 center;
142+
center.x = (float((meshletID % 9) + 1) / 10.0) * 2.0 - 1.0;
143+
center.y = (float((meshletID / 9) + 1) / 10.0) * 2.0 - 1.0;
144+
145+
VertexOut v;
146+
v.color = colors[meshletID & 3];
147+
148+
v.position = float4(center.x, center.y + 0.09, 0.0, 1.0);
149+
output.set_vertex(0, v);
150+
151+
v.position = float4(center.x - 0.09, center.y - 0.09, 0.0, 1.0);
152+
output.set_vertex(1, v);
153+
154+
v.position = float4(center.x + 0.09, center.y - 0.09, 0.0, 1.0);
155+
output.set_vertex(2, v);
156+
157+
output.set_index(0, 2);
158+
output.set_index(1, 1);
159+
output.set_index(2, 0);
160+
}
161+
162+
struct FSOut
163+
{
164+
float4 color [[color(0)]];
165+
};
166+
167+
fragment FSOut AmpPSmain(VertexOut in [[stage_in]])
168+
{
169+
FSOut out;
170+
out.color = float4(in.color, 1.0);
171+
return out;
172+
}
173+
)"
174+
};
175+
// clang-format on
176+
177+
} // namespace MSL
178+
179+
} // namespace

Tests/DiligentCoreAPITest/src/MeshShaderTest.cpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2019-2023 Diligent Graphics LLC
2+
* Copyright 2019-2026 Diligent Graphics LLC
33
* Copyright 2015-2019 Egor Yusov
44
*
55
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -50,6 +50,12 @@ void MeshShaderIndirectDrawReferenceVk(ISwapChain* pSwapChain);
5050
void AmplificationShaderDrawReferenceVk(ISwapChain* pSwapChain);
5151
#endif
5252

53+
#if METAL_SUPPORTED
54+
void MeshShaderDrawReferenceMtl(ISwapChain* pSwapChain);
55+
void MeshShaderIndirectDrawReferenceMtl(ISwapChain* pSwapChain);
56+
void AmplificationShaderDrawReferenceMtl(ISwapChain* pSwapChain);
57+
#endif
58+
5359
} // namespace Testing
5460

5561
} // namespace Diligent
@@ -95,6 +101,12 @@ TEST(MeshShaderTest, DrawTriangle)
95101
break;
96102
#endif
97103

104+
#if METAL_SUPPORTED
105+
case RENDER_DEVICE_TYPE_METAL:
106+
MeshShaderDrawReferenceMtl(pSwapChain);
107+
break;
108+
#endif
109+
98110
case RENDER_DEVICE_TYPE_UNDEFINED: // to avoid empty switch
99111
default:
100112
LOG_ERROR_AND_THROW("Unsupported device type");
@@ -202,10 +214,15 @@ TEST(MeshShaderTest, DrawTriangleIndirect)
202214
break;
203215
#endif
204216

217+
#if METAL_SUPPORTED
218+
case RENDER_DEVICE_TYPE_METAL:
219+
MeshShaderIndirectDrawReferenceMtl(pSwapChain);
220+
break;
221+
#endif
222+
205223
case RENDER_DEVICE_TYPE_D3D11:
206224
case RENDER_DEVICE_TYPE_GL:
207225
case RENDER_DEVICE_TYPE_GLES:
208-
case RENDER_DEVICE_TYPE_METAL:
209226
default:
210227
LOG_ERROR_AND_THROW("Unsupported device type");
211228
}
@@ -307,6 +324,10 @@ TEST(MeshShaderTest, DrawTriangleIndirectCount)
307324
{
308325
GTEST_SKIP() << "Mesh shader is not supported by this device";
309326
}
327+
if (pDevice->GetDeviceInfo().Type == RENDER_DEVICE_TYPE_METAL)
328+
{
329+
GTEST_SKIP() << "Indirect count for mesh shaders is not supported on Metal";
330+
}
310331

311332
GPUTestingEnvironment::ScopedReset EnvironmentAutoReset;
312333

@@ -337,7 +358,6 @@ TEST(MeshShaderTest, DrawTriangleIndirectCount)
337358
case RENDER_DEVICE_TYPE_D3D11:
338359
case RENDER_DEVICE_TYPE_GL:
339360
case RENDER_DEVICE_TYPE_GLES:
340-
case RENDER_DEVICE_TYPE_METAL:
341361
default:
342362
LOG_ERROR_AND_THROW("Unsupported device type");
343363
}
@@ -470,6 +490,12 @@ TEST(MeshShaderTest, DrawTrisWithAmplificationShader)
470490
break;
471491
#endif
472492

493+
#if METAL_SUPPORTED
494+
case RENDER_DEVICE_TYPE_METAL:
495+
AmplificationShaderDrawReferenceMtl(pSwapChain);
496+
break;
497+
#endif
498+
473499
case RENDER_DEVICE_TYPE_UNDEFINED: // to avoid empty switch
474500
default:
475501
LOG_ERROR_AND_THROW("Unsupported device type");

0 commit comments

Comments
 (0)