11import tgpu , { d , std } from 'typegpu' ;
2- import { ioLayout , type LayerData , type Network , weightsBiasesLayout } from './data.ts' ;
2+ import type { LayerData , Network } from './data.ts' ;
33import { downloadLayers } from './helpers.ts' ;
44import { defineControls } from '../../common/defineControls.ts' ;
55
66const SIZE = 28 ;
77
// All three GPU features are requested as optional; the flags below record
// which of them the device actually enabled.
const root = await tgpu.init({
  device: {
    optionalFeatures: ['timestamp-query', 'subgroups', 'shader-f16'],
  },
});

const hasTimestampQuery = root.enabledFeatures.has('timestamp-query');
const hasSubgroups = root.enabledFeatures.has('subgroups');
const hasShaderF16 = root.enabledFeatures.has('shader-f16');

// Initial value of the subgroup toggle mirrors device support.
let useSubgroups = hasSubgroups;

// Scalar schema shared by the layouts, buffers and shader accumulators below:
// f16 when the device enabled 'shader-f16', f32 otherwise.
const float = hasShaderF16 ? d.f16 : d.f32;
17+
/**
 * Bind group layout for one layer's activations.
 *
 * - `input`: storage array of `float` with no fixed length (default access).
 * - `output`: storage array of `float` with no fixed length, declared `'mutable'`
 *   so the compute kernels can write results into it.
 */
const ioLayout = tgpu.bindGroupLayout({
  input: { storage: d.arrayOf(float) },
  output: { storage: d.arrayOf(float), access: 'mutable' },
});
25+
/**
 * Bind group layout for one layer's parameters: `weights` and `biases`, both
 * storage arrays of `float` with no fixed length (default access).
 */
const weightsBiasesLayout = tgpu.bindGroupLayout({
  weights: { storage: d.arrayOf(float) },
  biases: { storage: d.arrayOf(float) },
});
30+
// One zero-initialised entry per pixel of the SIZE x SIZE grid.
const canvasData = new Array<number>(SIZE ** 2).fill(0);
1832
1933// Shaders
2034
/**
 * ReLU activation: returns `x` when it is positive and `0` otherwise.
 *
 * Marked with the 'use gpu' directive and typed as plain `number`, so it does
 * not hard-code f32 or f16 in its signature.
 */
function relu(x: number): number {
  'use gpu';
  return std.max(0, x);
}
2239
/**
 * Fallback dense-layer kernel: each invocation computes one output neuron.
 *
 * For neuron `i` it accumulates the dot product of the whole input vector with
 * the `inputSize` weights starting at `i * inputSize`, adds `biases[i]`, applies
 * ReLU and stores the result in `output[i]`.
 */
const defaultCompute = tgpu.computeFn({
  in: { gid: d.builtin.globalInvocationId },
  workgroupSize: [64],
})(({ gid }) => {
  const i = gid.x;
  const inputSize = ioLayout.$.input.length;
  const outputSize = ioLayout.$.output.length;

  // `i` indexes `output` and `biases`, so the guard must use the OUTPUT length.
  // Comparing against the input length lets surplus invocations of a 64-wide
  // workgroup touch memory past the end of `output`/`biases`/`weights` when
  // the layer shrinks, and silently skips valid neurons when the layer grows.
  if (i >= outputSize) {
    return;
  }

  // Weights for neuron `i` occupy `inputSize` consecutive entries.
  const weightsOffset = i * inputSize;
  let sum = float();

  for (let j = d.u32(0); j < inputSize; j++) {
    sum = std.fma(ioLayout.$.input[j], weightsBiasesLayout.$.weights[weightsOffset + j], sum);
  }

  const total = sum + weightsBiasesLayout.$.biases[i];
  ioLayout.$.output[i] = relu(total);
});
4260
43- const workgroupSize = tgpu . const ( d . u32 , 128 ) ;
4461const subgroupCompute = tgpu . computeFn ( {
4562 in : {
46- lid : d . builtin . localInvocationId ,
4763 wid : d . builtin . workgroupId ,
4864 sid : d . builtin . subgroupInvocationId ,
49- ssize : d . builtin . subgroupSize ,
65+ sgid : d . builtin . subgroupId ,
66+ nsg : d . builtin . numSubgroups ,
5067 } ,
51- workgroupSize : [ 128 ] ,
52- } ) ( ( { lid, wid, sid, ssize } ) => {
53- const subgroupId = d . u32 ( lid . x / ssize ) ;
54- const outputsPerWG = d . u32 ( workgroupSize . $ / ssize ) ;
55- const neuronIndex = wid . x * outputsPerWG + subgroupId ;
56-
68+ workgroupSize : [ 64 ] ,
69+ } ) ( ( { wid, sid, sgid, nsg } ) => {
5770 const outLen = ioLayout . $ . output . length ;
71+ const inputSize = ioLayout . $ . input . length ;
72+
73+ const neuronIndex = wid . x * nsg + sgid ;
5874 const valid = neuronIndex < outLen ;
5975
60- const inputSize = ioLayout . $ . input . length ;
76+ // Actual number of active lanes in this subgroup.
77+ const laneCount = std . subgroupAdd ( 1 ) ;
6178
62- let partial = d . f32 ( ) ;
79+ let partial = float ( 0 ) ;
6380
6481 if ( valid ) {
6582 const weightsOffset = neuronIndex * inputSize ;
66- for ( let j = sid ; j < inputSize ; j += ssize ) {
83+
84+ for ( let j = sid ; j < inputSize ; j += laneCount ) {
6785 partial = std . fma (
6886 ioLayout . $ . input [ j ] ,
6987 weightsBiasesLayout . $ . weights [ weightsOffset + j ] ,
@@ -74,7 +92,7 @@ const subgroupCompute = tgpu.computeFn({
7492
7593 const sum = std . subgroupAdd ( partial ) ;
7694
77- if ( valid && sid === 0 ) {
95+ if ( valid && std . subgroupElect ( ) ) {
7896 ioLayout . $ . output [ neuronIndex ] = relu ( sum + weightsBiasesLayout . $ . biases [ neuronIndex ] ) ;
7997 }
8098} ) ;
@@ -107,11 +125,11 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
107125 return {
108126 weights : weights . buffer ,
109127 biases : biases . buffer ,
110- state : root . createBuffer ( d . arrayOf ( d . f32 , biases . shape [ 0 ] ) ) . $usage ( 'storage' ) ,
128+ state : root . createBuffer ( d . arrayOf ( float , biases . shape [ 0 ] ) ) . $usage ( 'storage' ) ,
111129 } ;
112130 } ) ;
113131
114- const input = root . createBuffer ( d . arrayOf ( d . f32 , layers [ 0 ] [ 0 ] . shape [ 0 ] ) ) . $usage ( 'storage' ) ;
132+ const input = root . createBuffer ( d . arrayOf ( float , layers [ 0 ] [ 0 ] . shape [ 0 ] ) ) . $usage ( 'storage' ) ;
115133 const output = buffers [ buffers . length - 1 ] . state ;
116134
117135 const ioBindGroups = buffers . map ( ( _ , i ) =>
@@ -180,7 +198,7 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
180198 } ;
181199}
182200
183- const network = createNetwork ( await downloadLayers ( root ) ) ;
201+ const network = createNetwork ( await downloadLayers ( root , float ) ) ;
184202
185203// #region Example controls and cleanup
186204
@@ -386,7 +404,7 @@ export const controls = defineControls({
386404 'Test Resolution' : import . meta. env . DEV && {
387405 onButtonClick : ( ) =>
388406 [ defaultCompute , subgroupCompute ]
389- . map ( ( fn ) => tgpu . resolve ( [ fn ] , { enableExtensions : [ 'subgroups' ] } ) )
407+ . map ( ( fn ) => tgpu . resolve ( [ fn ] , { enableExtensions : [ 'subgroups' , 'f16' ] } ) )
390408 . map ( ( r ) => root . device . createShaderModule ( { code : r } ) ) ,
391409 } ,
392410} ) ;
0 commit comments