Skip to content

Commit 91fb9bb

Browse files
committed
use f16 and better subgroup shader
1 parent 5b116d1 commit 91fb9bb

4 files changed

Lines changed: 72 additions & 69 deletions

File tree

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,20 @@
1-
import tgpu, { d, type StorageFlag, type TgpuBuffer } from 'typegpu';
2-
3-
export const ReadonlyFloats = {
4-
storage: d.arrayOf(d.f32),
5-
access: 'readonly',
6-
} as const;
7-
8-
export const MutableFloats = {
9-
storage: d.arrayOf(d.f32),
10-
access: 'mutable',
11-
} as const;
12-
13-
export const ioLayout = tgpu.bindGroupLayout({
14-
input: ReadonlyFloats,
15-
output: MutableFloats,
16-
});
17-
18-
export const weightsBiasesLayout = tgpu.bindGroupLayout({
19-
weights: ReadonlyFloats,
20-
biases: ReadonlyFloats,
21-
});
1+
import { d, type StorageFlag, type TgpuBuffer } from 'typegpu';
222

233
export interface LayerData {
244
shape: readonly [number] | readonly [number, number];
25-
buffer: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
5+
buffer: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;
266
}
277

288
export interface Layer {
29-
weights: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
30-
biases: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
31-
state: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
9+
weights: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;
10+
biases: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;
11+
state: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;
3212
}
3313

3414
export interface Network {
3515
layers: Layer[];
36-
input: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
37-
output: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
16+
input: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;
17+
output: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;
3818

3919
inference(data: number[]): Promise<number[]>;
4020
}

apps/typegpu-docs/src/examples/algorithms/mnist-inference/helpers.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@ function getLayerData(layer: ArrayBuffer): {
3737
};
3838
}
3939

40-
export function downloadLayers(root: TgpuRoot): Promise<[LayerData, LayerData][]> {
40+
export function downloadLayers(
41+
root: TgpuRoot,
42+
floatShcema: d.F32 | d.F16,
43+
): Promise<[LayerData, LayerData][]> {
4144
const downloadLayer = async (fileName: string): Promise<LayerData> => {
4245
const buffer = await fetch(`/TypeGPU/assets/mnist-weights/${fileName}`).then((res) =>
4346
res.arrayBuffer(),
@@ -46,7 +49,7 @@ export function downloadLayers(root: TgpuRoot): Promise<[LayerData, LayerData][]
4649
const { shape, data } = getLayerData(buffer);
4750

4851
const layerBuffer = root
49-
.createBuffer(d.arrayOf(d.f32, data.length), [...data])
52+
.createBuffer(d.arrayOf(floatShcema, data.length), [...data])
5053
.$usage('storage');
5154

5255
return {

apps/typegpu-docs/src/examples/algorithms/mnist-inference/index.ts

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,87 @@
11
import tgpu, { d, std } from 'typegpu';
2-
import { ioLayout, type LayerData, type Network, weightsBiasesLayout } from './data.ts';
2+
import type { LayerData, Network } from './data.ts';
33
import { downloadLayers } from './helpers.ts';
44
import { defineControls } from '../../common/defineControls.ts';
55

66
const SIZE = 28;
77

88
const root = await tgpu.init({
9-
device: {
10-
optionalFeatures: ['timestamp-query', 'subgroups'],
11-
},
9+
device: { optionalFeatures: ['timestamp-query', 'subgroups', 'shader-f16'] },
1210
});
1311
const hasTimestampQuery = root.enabledFeatures.has('timestamp-query');
1412
const hasSubgroups = root.enabledFeatures.has('subgroups');
13+
const hasShaderF16 = root.enabledFeatures.has('shader-f16');
1514
let useSubgroups = hasSubgroups;
1615

16+
const float = hasShaderF16 ? d.f16 : d.f32;
17+
18+
const ioLayout = tgpu.bindGroupLayout({
19+
input: { storage: d.arrayOf(float) },
20+
output: {
21+
storage: d.arrayOf(float),
22+
access: 'mutable',
23+
},
24+
});
25+
26+
const weightsBiasesLayout = tgpu.bindGroupLayout({
27+
weights: { storage: d.arrayOf(float) },
28+
biases: { storage: d.arrayOf(float) },
29+
});
30+
1731
const canvasData = Array.from({ length: SIZE ** 2 }, () => 0);
1832

1933
// Shaders
2034

21-
const relu = tgpu.fn([d.f32], d.f32)((x) => std.max(0, x));
35+
function relu(x: number): number {
36+
'use gpu';
37+
return std.max(0, x);
38+
}
2239

2340
const defaultCompute = tgpu.computeFn({
24-
in: {
25-
gid: d.builtin.globalInvocationId,
26-
},
27-
workgroupSize: [1],
41+
in: { gid: d.builtin.globalInvocationId },
42+
workgroupSize: [64],
2843
})(({ gid }) => {
44+
const i = gid.x;
2945
const inputSize = ioLayout.$.input.length;
46+
if (i >= inputSize) {
47+
return;
48+
}
3049

31-
const i = gid.x;
3250
const weightsOffset = i * inputSize;
33-
let sum = d.f32();
51+
let sum = float();
3452

35-
for (let j = d.u32(); j < inputSize; j++) {
53+
for (let j = d.u32(0); j < inputSize; j++) {
3654
sum = std.fma(ioLayout.$.input[j], weightsBiasesLayout.$.weights[weightsOffset + j], sum);
3755
}
3856

3957
const total = sum + weightsBiasesLayout.$.biases[i];
4058
ioLayout.$.output[i] = relu(total);
4159
});
4260

43-
const workgroupSize = tgpu.const(d.u32, 128);
4461
const subgroupCompute = tgpu.computeFn({
4562
in: {
46-
lid: d.builtin.localInvocationId,
4763
wid: d.builtin.workgroupId,
4864
sid: d.builtin.subgroupInvocationId,
49-
ssize: d.builtin.subgroupSize,
65+
sgid: d.builtin.subgroupId,
66+
nsg: d.builtin.numSubgroups,
5067
},
51-
workgroupSize: [128],
52-
})(({ lid, wid, sid, ssize }) => {
53-
const subgroupId = d.u32(lid.x / ssize);
54-
const outputsPerWG = d.u32(workgroupSize.$ / ssize);
55-
const neuronIndex = wid.x * outputsPerWG + subgroupId;
56-
68+
workgroupSize: [64],
69+
})(({ wid, sid, sgid, nsg }) => {
5770
const outLen = ioLayout.$.output.length;
71+
const inputSize = ioLayout.$.input.length;
72+
73+
const neuronIndex = wid.x * nsg + sgid;
5874
const valid = neuronIndex < outLen;
5975

60-
const inputSize = ioLayout.$.input.length;
76+
// Actual number of active lanes in this subgroup.
77+
const laneCount = std.subgroupAdd(1);
6178

62-
let partial = d.f32();
79+
let partial = float(0);
6380

6481
if (valid) {
6582
const weightsOffset = neuronIndex * inputSize;
66-
for (let j = sid; j < inputSize; j += ssize) {
83+
84+
for (let j = sid; j < inputSize; j += laneCount) {
6785
partial = std.fma(
6886
ioLayout.$.input[j],
6987
weightsBiasesLayout.$.weights[weightsOffset + j],
@@ -74,7 +92,7 @@ const subgroupCompute = tgpu.computeFn({
7492

7593
const sum = std.subgroupAdd(partial);
7694

77-
if (valid && sid === 0) {
95+
if (valid && std.subgroupElect()) {
7896
ioLayout.$.output[neuronIndex] = relu(sum + weightsBiasesLayout.$.biases[neuronIndex]);
7997
}
8098
});
@@ -107,11 +125,11 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
107125
return {
108126
weights: weights.buffer,
109127
biases: biases.buffer,
110-
state: root.createBuffer(d.arrayOf(d.f32, biases.shape[0])).$usage('storage'),
128+
state: root.createBuffer(d.arrayOf(float, biases.shape[0])).$usage('storage'),
111129
};
112130
});
113131

114-
const input = root.createBuffer(d.arrayOf(d.f32, layers[0][0].shape[0])).$usage('storage');
132+
const input = root.createBuffer(d.arrayOf(float, layers[0][0].shape[0])).$usage('storage');
115133
const output = buffers[buffers.length - 1].state;
116134

117135
const ioBindGroups = buffers.map((_, i) =>
@@ -180,7 +198,7 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
180198
};
181199
}
182200

183-
const network = createNetwork(await downloadLayers(root));
201+
const network = createNetwork(await downloadLayers(root, float));
184202

185203
// #region Example controls and cleanup
186204

@@ -386,7 +404,7 @@ export const controls = defineControls({
386404
'Test Resolution': import.meta.env.DEV && {
387405
onButtonClick: () =>
388406
[defaultCompute, subgroupCompute]
389-
.map((fn) => tgpu.resolve([fn], { enableExtensions: ['subgroups'] }))
407+
.map((fn) => tgpu.resolve([fn], { enableExtensions: ['subgroups', 'f16'] }))
390408
.map((r) => root.device.createShaderModule({ code: r })),
391409
},
392410
});

apps/typegpu-docs/tests/individual-example-tests/mnist-inference.test.ts

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ describe('mnist inference example', () => {
2424

2525
expect(shaderCodes).toMatchInlineSnapshot(`
2626
"enable subgroups;
27+
enable f16;
2728
2829
@group(0) @binding(0) var<storage, read> input: array<f32>;
2930
@@ -37,9 +38,12 @@ describe('mnist inference example', () => {
3738
return max(0f, x);
3839
}
3940
40-
@compute @workgroup_size(1) fn defaultCompute(@builtin(global_invocation_id) gid: vec3u) {
41-
let inputSize = arrayLength(&input);
41+
@compute @workgroup_size(64) fn defaultCompute(@builtin(global_invocation_id) gid: vec3u) {
4242
let i = gid.x;
43+
let inputSize = arrayLength(&input);
44+
if ((i >= inputSize)) {
45+
return;
46+
}
4347
let weightsOffset = (i * inputSize);
4448
var sum = 0f;
4549
for (var j = 0u; (j < inputSize); j++) {
@@ -50,8 +54,7 @@ describe('mnist inference example', () => {
5054
}
5155
5256
enable subgroups;
53-
54-
const workgroupSize: u32 = 128u;
57+
enable f16;
5558
5659
@group(0) @binding(1) var<storage, read_write> output: array<f32>;
5760
@@ -65,22 +68,21 @@ describe('mnist inference example', () => {
6568
return max(0f, x);
6669
}
6770
68-
@compute @workgroup_size(128) fn subgroupCompute(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u, @builtin(subgroup_invocation_id) sid: u32, @builtin(subgroup_size) ssize: u32) {
69-
let subgroupId = u32((f32(lid.x) / f32(ssize)));
70-
let outputsPerWG = u32((f32(workgroupSize) / f32(ssize)));
71-
let neuronIndex = ((wid.x * outputsPerWG) + subgroupId);
71+
@compute @workgroup_size(64) fn subgroupCompute(@builtin(workgroup_id) wid: vec3u, @builtin(subgroup_invocation_id) sid: u32, @builtin(subgroup_id) sgid: u32, @builtin(num_subgroups) nsg: u32) {
7272
let outLen = arrayLength(&output);
73-
let valid = (neuronIndex < outLen);
7473
let inputSize = arrayLength(&input);
74+
let neuronIndex = ((wid.x * nsg) + sgid);
75+
let valid = (neuronIndex < outLen);
76+
let laneCount = subgroupAdd(1);
7577
var partial = 0f;
7678
if (valid) {
7779
let weightsOffset = (neuronIndex * inputSize);
78-
for (var j = sid; (j < inputSize); j += ssize) {
80+
for (var j = sid; (j < inputSize); j += u32(laneCount)) {
7981
partial = fma(input[j], weights[(weightsOffset + j)], partial);
8082
}
8183
}
8284
let sum = subgroupAdd(partial);
83-
if ((valid && (sid == 0u))) {
85+
if ((valid && subgroupElect())) {
8486
output[neuronIndex] = relu((sum + biases[neuronIndex]));
8587
}
8688
}"

0 commit comments

Comments
 (0)