Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 7 additions & 27 deletions apps/typegpu-docs/src/examples/algorithms/mnist-inference/data.ts
Original file line number Diff line number Diff line change
@@ -1,40 +1,20 @@
import tgpu, { d, type StorageFlag, type TgpuBuffer } from 'typegpu';

/**
 * Bind-group entry descriptor: a storage array of `f32` (no element count given)
 * declared with `'readonly'` access. `as const` keeps the `access` literal narrow.
 */
export const ReadonlyFloats = {
storage: d.arrayOf(d.f32),
access: 'readonly',
} as const;

/**
 * Bind-group entry descriptor: a storage array of `f32` (no element count given)
 * declared with `'mutable'` access. `as const` keeps the `access` literal narrow.
 */
export const MutableFloats = {
storage: d.arrayOf(d.f32),
access: 'mutable',
} as const;

/**
 * Bind group layout with two entries: `input` (read-only f32 array) and
 * `output` (mutable f32 array).
 */
export const ioLayout = tgpu.bindGroupLayout({
input: ReadonlyFloats,
output: MutableFloats,
});

/**
 * Bind group layout with two read-only f32 array entries: `weights` and `biases`.
 */
export const weightsBiasesLayout = tgpu.bindGroupLayout({
weights: ReadonlyFloats,
biases: ReadonlyFloats,
});
import { d, type StorageFlag, type TgpuBuffer } from 'typegpu';

/**
 * One tensor as returned by `downloadLayers`: its `shape` (a 1- or 2-element
 * tuple) and the storage-flagged GPU buffer holding its values.
 */
export interface LayerData {
shape: readonly [number] | readonly [number, number];
// NOTE(review): `buffer` is listed twice because this listing is a diff without +/- markers —
// presumably the f32-only line is the removed one and the `F32 | F16` line its replacement.
buffer: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
buffer: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;
}

/**
 * GPU buffers belonging to one network layer: `weights`, `biases`, and `state`.
 * `createNetwork` allocates `state` with one element per bias and uses the last
 * layer's `state` as the network output.
 */
export interface Layer {
// NOTE(review): each member appears twice because this listing is a diff without +/- markers —
// presumably the three f32-only lines are removed and the three `F32 | F16` lines replace them.
weights: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
biases: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
state: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
weights: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;
biases: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;
state: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;
}

/**
 * A complete network: its ordered `layers`, the `input` and `output` buffers,
 * and an async `inference` entry point that maps a plain number array to a
 * plain number array.
 */
export interface Network {
layers: Layer[];
// NOTE(review): `input`/`output` appear twice because this listing is a diff without +/- markers —
// presumably the f32-only pair is removed and the `F32 | F16` pair replaces it.
input: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
output: TgpuBuffer<d.WgslArray<d.F32>> & StorageFlag;
input: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;
output: TgpuBuffer<d.WgslArray<d.F32 | d.F16>> & StorageFlag;

inference(data: number[]): Promise<number[]>;
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ function getLayerData(layer: ArrayBuffer): {
};
}

export function downloadLayers(root: TgpuRoot): Promise<[LayerData, LayerData][]> {
export function downloadLayers(
root: TgpuRoot,
floatSchema: d.F32 | d.F16,
): Promise<[LayerData, LayerData][]> {
Comment on lines +40 to +43
Copy link

Copilot AI Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Parameter name floatShcema appears to be a typo; rename to floatSchema to avoid propagating a misspelled identifier through the API.

Copilot uses AI. Check for mistakes.
Comment on lines +40 to +43
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in parameter name floatShcema (should be floatSchema). Keeping the misspelling makes the API harder to read/search and increases the chance of propagating the typo to call sites.

Copilot uses AI. Check for mistakes.
const downloadLayer = async (fileName: string): Promise<LayerData> => {
const buffer = await fetch(`/TypeGPU/assets/mnist-weights/${fileName}`).then((res) =>
res.arrayBuffer(),
Expand All @@ -46,7 +49,7 @@ export function downloadLayers(root: TgpuRoot): Promise<[LayerData, LayerData][]
const { shape, data } = getLayerData(buffer);

const layerBuffer = root
.createBuffer(d.arrayOf(d.f32, data.length), [...data])
.createBuffer(d.arrayOf(floatSchema, data.length), [...data])
.$usage('storage');

return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

<div class="info">
<div>Subgroups: <span id="subgroups-status">-</span></div>
<div>Precision: <span id="precision-status">-</span></div>
<div>Inference: <span id="inference-time">-</span></div>
</div>
</div>
Expand Down
99 changes: 65 additions & 34 deletions apps/typegpu-docs/src/examples/algorithms/mnist-inference/index.ts
Original file line number Diff line number Diff line change
@@ -1,69 +1,86 @@
import tgpu, { d, std } from 'typegpu';
import { ioLayout, type LayerData, type Network, weightsBiasesLayout } from './data.ts';
import type { LayerData, Network } from './data.ts';
import { downloadLayers } from './helpers.ts';
import { defineControls } from '../../common/defineControls.ts';

// Side length used for the canvas data; `canvasData` holds SIZE ** 2 entries.
const SIZE = 28;
// Invocations per workgroup for both compute entry points; also the divisor used
// when computing the dispatch count for the non-subgroup pipeline.
const WORKGROUP_SIZE = 64;

// Create the TypeGPU root, asking for device features as optional ones; which of
// them were actually enabled is read back from `root.enabledFeatures` below.
// NOTE(review): the `device` key appears twice because this listing is a diff without
// +/- markers — presumably the multi-line form is removed and the one-line form, which
// additionally requests 'shader-f16', replaces it.
const root = await tgpu.init({
device: {
optionalFeatures: ['timestamp-query', 'subgroups'],
},
device: { optionalFeatures: ['timestamp-query', 'subgroups', 'shader-f16'] },
});
const hasTimestampQuery = root.enabledFeatures.has('timestamp-query');
const hasSubgroups = root.enabledFeatures.has('subgroups');
const hasShaderF16 = root.enabledFeatures.has('shader-f16');
// Mutable: starts at the detected capability and is consulted when picking a pipeline.
let useSubgroups = hasSubgroups;

// Element schema shared by the layouts, the buffers created in this file, and
// `downloadLayers`: f16 when 'shader-f16' is enabled, f32 otherwise.
const float = hasShaderF16 ? d.f16 : d.f32;

// Bind group layout for a layer's activations: `input` and `output` are arrays of
// `float` (no element count given). Only `output` states an access mode ('mutable');
// `input` leaves it unspecified — the snapshot further down resolves it to
// `var<storage, read>`.
const ioLayout = tgpu.bindGroupLayout({
input: { storage: d.arrayOf(float) },
output: {
storage: d.arrayOf(float),
access: 'mutable',
},
});

// Bind group layout for a layer's parameters: `weights` and `biases`, both arrays
// of `float` with no access mode stated.
const weightsBiasesLayout = tgpu.bindGroupLayout({
weights: { storage: d.arrayOf(float) },
biases: { storage: d.arrayOf(float) },
});

// Flat array of SIZE * SIZE numbers, all initialised to 0.
const canvasData = Array.from({ length: SIZE ** 2 }, () => 0);

// Shaders

// ReLU activation: max(0, x).
// NOTE(review): two spellings appear back to back because this listing is a diff without
// +/- markers — a `tgpu.fn` with an explicit f32 -> f32 signature, then a plain function
// carrying the 'use gpu' directive with no fixed float width in its signature. Presumably
// the first is removed and the second replaces it.
const relu = tgpu.fn([d.f32], d.f32)((x) => std.max(0, x));
function relu(x: number): number {
'use gpu';
return std.max(0, x);
}

// Compute entry point that produces one output element per invocation:
//   output[i] = relu(biases[i] + sum over j of input[j] * weights[i * inputSize + j])
// where i is the x component of the global invocation id.
// NOTE(review): several statements appear twice (the `in`/`workgroupSize` options, the
// `inputSize` read, the `sum` initialiser, the `for` header) because this listing is a diff
// without +/- markers; presumably the first of each pair is removed and the second replaces it.
const defaultCompute = tgpu.computeFn({
in: {
gid: d.builtin.globalInvocationId,
},
workgroupSize: [1],
in: { gid: d.builtin.globalInvocationId },
workgroupSize: [WORKGROUP_SIZE],
})(({ gid }) => {
const inputSize = ioLayout.$.input.length;

const i = gid.x;
const outLen = ioLayout.$.output.length;
// Invocations past the end of the output do nothing; the non-subgroup dispatch count
// is rounded up to a whole number of workgroups, so such invocations can exist.
if (i >= outLen) {
return;
}

const inputSize = ioLayout.$.input.length;
// Offset of the first weight used for output i; weights for one output are read contiguously.
const weightsOffset = i * inputSize;
let sum = d.f32();
let sum = float();

for (let j = d.u32(); j < inputSize; j++) {
for (let j = d.u32(0); j < inputSize; j++) {
// Accumulate input[j] * weight + sum via fused multiply-add.
sum = std.fma(ioLayout.$.input[j], weightsBiasesLayout.$.weights[weightsOffset + j], sum);
}

const total = sum + weightsBiasesLayout.$.biases[i];
ioLayout.$.output[i] = relu(total);
});

const workgroupSize = tgpu.const(d.u32, 128);
const subgroupCompute = tgpu.computeFn({
in: {
lid: d.builtin.localInvocationId,
wid: d.builtin.workgroupId,
sid: d.builtin.subgroupInvocationId,
ssize: d.builtin.subgroupSize,
sgid: d.builtin.subgroupId,
subgroupSize: d.builtin.subgroupSize,
},
workgroupSize: [128],
})(({ lid, wid, sid, ssize }) => {
const subgroupId = d.u32(lid.x / ssize);
const outputsPerWG = d.u32(workgroupSize.$ / ssize);
const neuronIndex = wid.x * outputsPerWG + subgroupId;

workgroupSize: [WORKGROUP_SIZE],
})(({ wid, sid, sgid, subgroupSize }) => {
const outLen = ioLayout.$.output.length;
const valid = neuronIndex < outLen;

const inputSize = ioLayout.$.input.length;

let partial = d.f32();
const neuronIndex = wid.x;
const valid = sgid === 0 && neuronIndex < outLen;

let partial = float(0);

if (valid) {
const weightsOffset = neuronIndex * inputSize;
for (let j = sid; j < inputSize; j += ssize) {

for (let j = sid; j < inputSize; j += subgroupSize) {
partial = std.fma(
ioLayout.$.input[j],
weightsBiasesLayout.$.weights[weightsOffset + j],
Expand All @@ -74,7 +91,7 @@ const subgroupCompute = tgpu.computeFn({

const sum = std.subgroupAdd(partial);

if (valid && sid === 0) {
if (valid && std.subgroupElect()) {
ioLayout.$.output[neuronIndex] = relu(sum + weightsBiasesLayout.$.biases[neuronIndex]);
}
});
Expand Down Expand Up @@ -107,11 +124,11 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
return {
weights: weights.buffer,
biases: biases.buffer,
state: root.createBuffer(d.arrayOf(d.f32, biases.shape[0])).$usage('storage'),
state: root.createBuffer(d.arrayOf(float, biases.shape[0])).$usage('storage'),
};
});

const input = root.createBuffer(d.arrayOf(d.f32, layers[0][0].shape[0])).$usage('storage');
const input = root.createBuffer(d.arrayOf(float, layers[0][0].shape[0])).$usage('storage');
const output = buffers[buffers.length - 1].state;

const ioBindGroups = buffers.map((_, i) =>
Expand All @@ -137,7 +154,8 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
}
input.write(data);

const pipeline = useSubgroups && pipelines.subgroup ? pipelines.subgroup : pipelines.default;
const subgroupPipeline = useSubgroups ? pipelines.subgroup : null;
const pipeline = subgroupPipeline ?? pipelines.default;

// Run the network
for (let i = 0; i < buffers.length; i++) {
Expand All @@ -155,7 +173,10 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
boundPipeline = boundPipeline.withTimestampWrites(descriptor);
}

boundPipeline.dispatchWorkgroups(buffers[i].biases.dataType.elementCount);
const outputCount = buffers[i].biases.dataType.elementCount;
boundPipeline.dispatchWorkgroups(
subgroupPipeline ? outputCount : Math.ceil(outputCount / WORKGROUP_SIZE),
);
Comment on lines +176 to +179
Copy link

Copilot AI Apr 29, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dispatchWorkgroups uses outputCount when the subgroup pipeline is selected, but subgroupCompute computes num_subgroups outputs per workgroup (neuronIndex = wid.x * nsg + sgid). This over-dispatches workgroups by a factor of nsg (e.g., 2x for 64 threads with 32-wide subgroups), doing unnecessary work for larger layers. Consider either dispatching ceil(outputCount / outputsPerWorkgroup) (if you can determine outputsPerWorkgroup) or adjusting the shader/work mapping so each workgroup corresponds to exactly one output when dispatch count must be outputCount.

Copilot uses AI. Check for mistakes.
}

if (querySet?.available) {
Expand All @@ -180,7 +201,7 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
};
}

const network = createNetwork(await downloadLayers(root));
const network = createNetwork(await downloadLayers(root, float));

// #region Example controls and cleanup

Expand All @@ -189,6 +210,7 @@ const context = canvas.getContext('2d') as CanvasRenderingContext2D;

const bars = Array.from(document.querySelectorAll('.bar')) as HTMLDivElement[];
const subgroupsEl = document.getElementById('subgroups-status') as HTMLSpanElement;
const precisionEl = document.getElementById('precision-status') as HTMLSpanElement;
const inferenceTimeEl = document.getElementById('inference-time') as HTMLSpanElement;

const uiState = {
Expand Down Expand Up @@ -252,6 +274,8 @@ function updateSubgroupsStatus() {
}

updateSubgroupsStatus();
precisionEl.textContent = hasShaderF16 ? 'f16' : 'f32';
precisionEl.className = hasShaderF16 ? 'enabled' : 'disabled';

run();

Expand Down Expand Up @@ -385,8 +409,15 @@ export const controls = defineControls({
},
'Test Resolution': import.meta.env.DEV && {
onButtonClick: () =>
[defaultCompute, subgroupCompute]
.map((fn) => tgpu.resolve([fn], { enableExtensions: ['subgroups'] }))
[defaultCompute, ...(hasSubgroups ? [subgroupCompute] : [])]
.map((fn) =>
tgpu.resolve([fn], {
enableExtensions: [
...(hasSubgroups ? ['subgroups' as const] : []),
...(hasShaderF16 ? ['f16' as const] : []),
],
}),
)
.map((r) => root.device.createShaderModule({ code: r })),
Comment thread
reczkok marked this conversation as resolved.
},
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ import { mockMnistWeights } from './utils/commonMocks.ts';
describe('mnist inference example', () => {
setupCommonMocks();

it('should produce valid code', async ({ device }) => {
it('should produce valid code', async ({ adapter, device }) => {
for (const feature of ['subgroups', 'shader-f16'] satisfies GPUFeatureName[]) {
adapter.features.add(feature);
(device.features as Set<GPUFeatureName>).add(feature);
}

const shaderCodes = await runExampleTest(
{
category: 'algorithms',
Expand All @@ -24,24 +29,29 @@ describe('mnist inference example', () => {

expect(shaderCodes).toMatchInlineSnapshot(`
"enable subgroups;
enable f16;

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f16>;

@group(1) @binding(0) var<storage, read> weights: array<f32>;
@group(0) @binding(0) var<storage, read> input: array<f16>;

@group(1) @binding(1) var<storage, read> biases: array<f32>;
@group(1) @binding(0) var<storage, read> weights: array<f16>;

@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(1) @binding(1) var<storage, read> biases: array<f16>;

fn relu(x: f32) -> f32 {
return max(0f, x);
fn relu(x: f16) -> f16 {
return max(0h, x);
}

@compute @workgroup_size(1) fn defaultCompute(@builtin(global_invocation_id) gid: vec3u) {
let inputSize = arrayLength(&input);
@compute @workgroup_size(64) fn defaultCompute(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
let outLen = arrayLength(&output);
if ((i >= outLen)) {
return;
}
let inputSize = arrayLength(&input);
let weightsOffset = (i * inputSize);
var sum = 0f;
var sum = 0h;
for (var j = 0u; (j < inputSize); j++) {
sum = fma(input[j], weights[(weightsOffset + j)], sum);
}
Expand All @@ -50,37 +60,34 @@ describe('mnist inference example', () => {
}

enable subgroups;
enable f16;

const workgroupSize: u32 = 128u;

@group(0) @binding(1) var<storage, read_write> output: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f16>;

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(0) var<storage, read> input: array<f16>;

@group(1) @binding(0) var<storage, read> weights: array<f32>;
@group(1) @binding(0) var<storage, read> weights: array<f16>;

@group(1) @binding(1) var<storage, read> biases: array<f32>;
@group(1) @binding(1) var<storage, read> biases: array<f16>;

fn relu(x: f32) -> f32 {
return max(0f, x);
fn relu(x: f16) -> f16 {
return max(0h, x);
}

@compute @workgroup_size(128) fn subgroupCompute(@builtin(local_invocation_id) lid: vec3u, @builtin(workgroup_id) wid: vec3u, @builtin(subgroup_invocation_id) sid: u32, @builtin(subgroup_size) ssize: u32) {
let subgroupId = u32((f32(lid.x) / f32(ssize)));
let outputsPerWG = u32((f32(workgroupSize) / f32(ssize)));
let neuronIndex = ((wid.x * outputsPerWG) + subgroupId);
@compute @workgroup_size(64) fn subgroupCompute(@builtin(workgroup_id) wid: vec3u, @builtin(subgroup_invocation_id) sid: u32, @builtin(subgroup_id) sgid: u32, @builtin(subgroup_size) subgroupSize: u32) {
let outLen = arrayLength(&output);
let valid = (neuronIndex < outLen);
let inputSize = arrayLength(&input);
var partial = 0f;
let neuronIndex = wid.x;
let valid = ((sgid == 0u) && (neuronIndex < outLen));
var partial = 0h;
if (valid) {
let weightsOffset = (neuronIndex * inputSize);
for (var j = sid; (j < inputSize); j += ssize) {
for (var j = sid; (j < inputSize); j += subgroupSize) {
partial = fma(input[j], weights[(weightsOffset + j)], partial);
}
}
let sum = subgroupAdd(partial);
if ((valid && (sid == 0u))) {
if ((valid && subgroupElect())) {
output[neuronIndex] = relu((sum + biases[neuronIndex]));
}
}"
Expand Down
Loading