@@ -4,6 +4,7 @@ import { downloadLayers } from './helpers.ts';
44import { defineControls } from '../../common/defineControls.ts' ;
55
66const SIZE = 28 ;
7+ const WORKGROUP_SIZE = 64 ;
78
89const root = await tgpu . init ( {
910 device : { optionalFeatures : [ 'timestamp-query' , 'subgroups' , 'shader-f16' ] } ,
@@ -39,14 +40,15 @@ function relu(x: number): number {
3940
4041const defaultCompute = tgpu . computeFn ( {
4142 in : { gid : d . builtin . globalInvocationId } ,
42- workgroupSize : [ 64 ] ,
43+ workgroupSize : [ WORKGROUP_SIZE ] ,
4344} ) ( ( { gid } ) => {
4445 const i = gid . x ;
45- const inputSize = ioLayout . $ . input . length ;
46- if ( i >= inputSize ) {
46+ const outLen = ioLayout . $ . output . length ;
47+ if ( i >= outLen ) {
4748 return ;
4849 }
4950
51+ const inputSize = ioLayout . $ . input . length ;
5052 const weightsOffset = i * inputSize ;
5153 let sum = float ( ) ;
5254
@@ -65,7 +67,7 @@ const subgroupCompute = tgpu.computeFn({
6567 sgid : d . builtin . subgroupId ,
6668 nsg : d . builtin . numSubgroups ,
6769 } ,
68- workgroupSize : [ 64 ] ,
70+ workgroupSize : [ WORKGROUP_SIZE ] ,
6971} ) ( ( { wid, sid, sgid, nsg } ) => {
7072 const outLen = ioLayout . $ . output . length ;
7173 const inputSize = ioLayout . $ . input . length ;
@@ -155,7 +157,8 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
155157 }
156158 input . write ( data ) ;
157159
158- const pipeline = useSubgroups && pipelines . subgroup ? pipelines . subgroup : pipelines . default ;
160+ const subgroupPipeline = useSubgroups ? pipelines . subgroup : null ;
161+ const pipeline = subgroupPipeline ?? pipelines . default ;
159162
160163 // Run the network
161164 for ( let i = 0 ; i < buffers . length ; i ++ ) {
@@ -173,7 +176,10 @@ function createNetwork(layers: [LayerData, LayerData][]): Network {
173176 boundPipeline = boundPipeline . withTimestampWrites ( descriptor ) ;
174177 }
175178
176- boundPipeline . dispatchWorkgroups ( buffers [ i ] . biases . dataType . elementCount ) ;
179+ const outputCount = buffers [ i ] . biases . dataType . elementCount ;
180+ boundPipeline . dispatchWorkgroups (
181+ subgroupPipeline ? outputCount : Math . ceil ( outputCount / WORKGROUP_SIZE ) ,
182+ ) ;
177183 }
178184
179185 if ( querySet ?. available ) {
0 commit comments