Skip to content

Commit 82d84e8

Browse files
committed
do not over-dispatch
1 parent 7d1a682 commit 82d84e8

2 files changed

Lines changed: 17 additions & 10 deletions

File tree

apps/typegpu-docs/src/examples/algorithms/mnist-inference/index.ts

Lines changed: 12 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@ import { downloadLayers } from './helpers.ts';
44
import { defineControls } from '../../common/defineControls.ts';
55

66
const SIZE = 28;
7+
const WORKGROUP_SIZE = 64;
78

89
const root = await tgpu.init({
910
device: { optionalFeatures: ['timestamp-query', 'subgroups', 'shader-f16'] },
@@ -39,14 +40,15 @@ function relu(x: number): number {
3940

4041
const defaultCompute = tgpu.computeFn({
4142
in: { gid: d.builtin.globalInvocationId },
42-
workgroupSize: [64],
43+
workgroupSize: [WORKGROUP_SIZE],
4344
})(({ gid }) => {
4445
const i = gid.x;
45-
const inputSize = ioLayout.$.input.length;
46-
if (i >= inputSize) {
46+
const outLen = ioLayout.$.output.length;
47+
if (i >= outLen) {
4748
return;
4849
}
4950

51+
const inputSize = ioLayout.$.input.length;
5052
const weightsOffset = i * inputSize;
5153
let sum = float();
5254

@@ -65,7 +67,7 @@ const subgroupCompute = tgpu.computeFn({
6567
sgid: d.builtin.subgroupId,
6668
nsg: d.builtin.numSubgroups,
6769
},
68-
workgroupSize: [64],
70+
workgroupSize: [WORKGROUP_SIZE],
6971
})(({ wid, sid, sgid, nsg }) => {
7072
const outLen = ioLayout.$.output.length;
7173
const inputSize = ioLayout.$.input.length;
@@ -155,7 +157,8 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
155157
}
156158
input.write(data);
157159

158-
const pipeline = useSubgroups && pipelines.subgroup ? pipelines.subgroup : pipelines.default;
160+
const subgroupPipeline = useSubgroups ? pipelines.subgroup : null;
161+
const pipeline = subgroupPipeline ?? pipelines.default;
159162

160163
// Run the network
161164
for (let i = 0; i < buffers.length; i++) {
@@ -173,7 +176,10 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
173176
boundPipeline = boundPipeline.withTimestampWrites(descriptor);
174177
}
175178

176-
boundPipeline.dispatchWorkgroups(buffers[i].biases.dataType.elementCount);
179+
const outputCount = buffers[i].biases.dataType.elementCount;
180+
boundPipeline.dispatchWorkgroups(
181+
subgroupPipeline ? outputCount : Math.ceil(outputCount / WORKGROUP_SIZE),
182+
);
177183
}
178184

179185
if (querySet?.available) {

apps/typegpu-docs/tests/individual-example-tests/mnist-inference.test.ts

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -26,24 +26,25 @@ describe('mnist inference example', () => {
2626
"enable subgroups;
2727
enable f16;
2828
29+
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
30+
2931
@group(0) @binding(0) var<storage, read> input: array<f32>;
3032
3133
@group(1) @binding(0) var<storage, read> weights: array<f32>;
3234
3335
@group(1) @binding(1) var<storage, read> biases: array<f32>;
3436
35-
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
36-
3737
fn relu(x: f32) -> f32 {
3838
return max(0f, x);
3939
}
4040
4141
@compute @workgroup_size(64) fn defaultCompute(@builtin(global_invocation_id) gid: vec3u) {
4242
let i = gid.x;
43-
let inputSize = arrayLength(&input);
44-
if ((i >= inputSize)) {
43+
let outLen = arrayLength(&output);
44+
if ((i >= outLen)) {
4545
return;
4646
}
47+
let inputSize = arrayLength(&input);
4748
let weightsOffset = (i * inputSize);
4849
var sum = 0f;
4950
for (var j = 0u; (j < inputSize); j++) {

0 commit comments

Comments (0)