@@ -35,34 +35,59 @@ public LogitsFP16Layer(String name, State state, Weights weights, Configuration
3535 this .logitsTaskGraph = setupLogitsTaskGraph (fp16Weights , config );
3636 }
3737
38+ private TaskGraph setupLogitNonNVidia (FP16Weights weights , Configuration config ) {
39+ TaskGraph logits = new TaskGraph ("logits" )
40+ .consumeFromDevice (lastTaskGraphID ,
41+ state .wrapX
42+ )
43+ .transferToDevice (DataTransferMode .EVERY_EXECUTION ,
44+ state .tempLogits
45+ )
46+ .transferToDevice (DataTransferMode .FIRST_EXECUTION ,
47+ context ,
48+ state .wrapLogits ,
49+ weights .wclsHalfFloat ,
50+ weights .rms_final_weight_as_floatArray
51+ )
52+ .task ("reductionsOneBlockLogits" , TransformerComputeKernels ::reductionOneBlockWithLayer , context , state .tempLogits ,
53+ state .wrapX , config .dim (), config .rmsNormEps (), state .localSize )
54+ .task ("mapContextLogits" , TransformerComputeKernels ::reductionOneBlock2WithLogits , context , state .wrapX ,
55+ weights .rms_final_weight_as_floatArray , state .tempLogits );
56+ logits .task ("projection" , TransformerComputeKernelsLayered ::matrixVectorGeneric , //
57+ context , state .wrapX , state .wrapLogits , weights .wclsHalfFloat , //
58+ config .dim (), config .vocabularySize (), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS ); //
59+ logits .transferToHost (DataTransferMode .EVERY_EXECUTION , state .wrapLogits );
60+ return logits ;
61+ }
62+
3863 /**
3964 * Builds the logits computation graph.
4065 */
4166 private TaskGraph setupLogitsTaskGraph (FP16Weights weights , Configuration config ) {
4267
43- TaskGraph logits = new TaskGraph ("logits" )
44- .consumeFromDevice (lastTaskGraphID ,
45- state .wrapX
46- )
47- .transferToDevice (DataTransferMode .EVERY_EXECUTION ,
48- state .tempLogits
49- )
50- .transferToDevice (DataTransferMode .FIRST_EXECUTION ,
51- context ,
52- state .wrapLogits ,
53- weights .wclsHalfFloat ,
54- weights .rms_final_weight_as_floatArray
55- )
56- .task ("reductionsOneBlockLogits" , TransformerComputeKernels ::reductionOneBlockWithLayer , context , state .tempLogits ,
57- state .wrapX , config .dim (), config .rmsNormEps (), state .localSize )
58- .task ("mapContextLogits" , TransformerComputeKernels ::reductionOneBlock2WithLogits , context , state .wrapX ,
59- weights .rms_final_weight_as_floatArray , state .tempLogits );
60- logits .task ("projection" , TransformerComputeKernelsLayered ::matrixVectorGeneric ,
61- context , state .wrapX , state .wrapLogits , weights .wclsHalfFloat ,
62- config .dim (), config .vocabularySize (), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS );
63- logits .transferToHost (DataTransferMode .EVERY_EXECUTION , state .wrapLogits );
64-
65- return logits ;
68+ TaskGraph logits = new TaskGraph ("logits" )
69+ .consumeFromDevice (lastTaskGraphID ,
70+ state .wrapX
71+ )
72+ .transferToDevice (DataTransferMode .EVERY_EXECUTION ,
73+ state .tempLogits
74+ )
75+ .transferToDevice (DataTransferMode .FIRST_EXECUTION ,
76+ context ,
77+ state .wrapLogits ,
78+ weights .wclsHalfFloat ,
79+ weights .rms_final_weight_as_floatArray
80+ )
81+ .task ("reductionsOneBlockLogits" , TransformerComputeKernels ::reductionOneBlockWithLayer , context , state .tempLogits ,
82+ state .wrapX , config .dim (), config .rmsNormEps (), state .localSize )
83+ .task ("mapContextLogits" , TransformerComputeKernels ::reductionOneBlock2WithLogits , context , state .wrapX ,
84+ weights .rms_final_weight_as_floatArray , state .tempLogits );
85+ logits .task ("projection" , TransformerComputeKernelsLayered ::matrixVectorGeneric ,
86+ context , state .wrapX , state .wrapLogits , weights .wclsHalfFloat ,
87+ config .dim (), config .vocabularySize (), LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS );
88+ logits .transferToHost (DataTransferMode .EVERY_EXECUTION , state .wrapLogits );
89+
90+ return logits ;
6691 }
6792
6893 private GridScheduler setupGridSchedulerForLogits (Configuration config ) {
@@ -85,22 +110,42 @@ private GridScheduler setupGridSchedulerForLogits(Configuration config) {
85110 return scheduler ;
86111 }
87112
88- @ Override
89- public GridScheduler updateGridScheduler (GridScheduler scheduler ) {
90- // RMSNorm operations
91- WorkerGrid rmsNormWorker = new WorkerGrid1D (config .dim ());
92- rmsNormWorker .setGlobalWork (config .dim (), 1 , 1 );
93- rmsNormWorker .setLocalWork (256 , 1 , 1 );
94-
95- // Projection kernel (vocabulary size × hidden dim)
96- int vocabSizeGlobal = config .vocabularySize () * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS ;
97- WorkerGrid projectionWorker = new WorkerGrid1D (vocabSizeGlobal );
98- projectionWorker .setLocalWork (LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS , 1 , 1 );
113+ // @Override
114+ // public GridScheduler updateGridScheduler(GridScheduler scheduler) {
115+ // // RMSNorm operations
116+ // WorkerGrid rmsNormWorker = new WorkerGrid1D(config.dim());
117+ // rmsNormWorker.setGlobalWork(config.dim(), 1, 1);
118+ // rmsNormWorker.setLocalWork(256, 1, 1);
119+ //
120+ // // Projection kernel (vocabulary size × hidden dim)
121+ // int vocabSizeGlobal = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
122+ // WorkerGrid projectionWorker = new WorkerGrid1D(vocabSizeGlobal);
123+ // projectionWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
124+ //
125+ // scheduler.addWorkerGrid("logits.projection", projectionWorker);
126+ // scheduler.addWorkerGrid("logits.reductionsOneBlockLogits", rmsNormWorker);
127+ // scheduler.addWorkerGrid("logits.mapContextLogits", rmsNormWorker);
128+ //
129+ // return scheduler;
130+ // }
99131
100- scheduler .addWorkerGrid ("logits.projection" , projectionWorker );
101- scheduler .addWorkerGrid ("logits.reductionsOneBlockLogits" , rmsNormWorker );
102- scheduler .addWorkerGrid ("logits.mapContextLogits" , rmsNormWorker );
103132
133+ @ Override
134+ public GridScheduler updateGridScheduler (GridScheduler tornadoForwardScheduler ) {
135+ // RMSNorm operations
136+ WorkerGrid rmsNormWorker = new WorkerGrid1D (config .dim ());
137+ rmsNormWorker .setGlobalWork (config .dim (), 1 , 1 );
138+ rmsNormWorker .setLocalWork (256 , 1 , 1 );
139+
140+ // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.vocabularySize,1,1], localWorkSize=[16,1,1])
141+ // CUDA equivalent: kernel<<<dim3((config.vocabularySize+15)/16,1,1), dim3(16,1,1)>>>
142+ int vocabSizeRowMajor = config .vocabularySize () * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS ;
143+ WorkerGrid vocabWorker = new WorkerGrid1D (vocabSizeRowMajor );
144+ vocabWorker .setLocalWork (LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS , 1 , 1 );
145+
146+ tornadoForwardScheduler .addWorkerGrid ("logits.projection" , vocabWorker );
147+ tornadoForwardScheduler .addWorkerGrid ("logits.reductionsOneBlockLogits" , rmsNormWorker );
148+ tornadoForwardScheduler .addWorkerGrid ("logits.mapContextLogits" , rmsNormWorker );
104149 return scheduler ;
105150 }
106151
0 commit comments