Skip to content

Commit 72c6619

Browse files
committed
Introduce WorkerGridFactory for standardized worker grid creation
- Added utility methods for creating workers: RMSNorm, QKV Matmul, RoPE, Attention, FFN Gate+Up, and FFN Down.
- Centralized worker grid logic to improve code readability and maintainability.
1 parent 1e71fa6 commit 72c6619

1 file changed

Lines changed: 92 additions & 0 deletions

File tree

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package org.beehive.gpullama3.tornadovm.layerplanner;
2+
3+
import uk.ac.manchester.tornado.api.WorkerGrid;
4+
import uk.ac.manchester.tornado.api.WorkerGrid1D;
5+
import uk.ac.manchester.tornado.api.WorkerGrid2D;
6+
7+
public class WorkerGridFactory {
8+
private static final int DEFAULT_WORK_GROUP_SIZE = 32;
9+
10+
/**
11+
* RMS Norm worker: parallel reduction across dimension
12+
*/
13+
public static WorkerGrid createRmsNormWorker(int dim, int localSize) {
14+
WorkerGrid worker = new WorkerGrid1D(dim);
15+
worker.setGlobalWork(dim, 1, 1);
16+
worker.setLocalWork(localSize, 1, 1);
17+
return worker;
18+
}
19+
20+
/**
21+
* QKV matmul worker: combined projection output
22+
*/
23+
public static WorkerGrid createQkvMatmulWorker(int opSize) {
24+
int global = opSize * DEFAULT_WORK_GROUP_SIZE;
25+
WorkerGrid worker = new WorkerGrid1D(global);
26+
worker.setLocalWork(DEFAULT_WORK_GROUP_SIZE, 1, 1);
27+
return worker;
28+
}
29+
30+
/**
31+
* RoPE worker: 2D grid for position encoding
32+
*/
33+
public static WorkerGrid createRoPEWorker(int numberOfHeads, int headSize) {
34+
int ic = headSize / 2;
35+
WorkerGrid worker = new WorkerGrid2D(numberOfHeads, ic);
36+
worker.setGlobalWork(numberOfHeads, ic, 1);
37+
worker.setLocalWork(8, 1, 1);
38+
return worker;
39+
}
40+
41+
/**
42+
* Attention worker: compute all heads in parallel
43+
*/
44+
public static WorkerGrid createAttentionWorker(int numberOfHeads, int headSize) {
45+
int optimalLocalSize = findOptimalLocalSize(headSize);
46+
WorkerGrid worker = new WorkerGrid1D(numberOfHeads);
47+
worker.setGlobalWork(numberOfHeads * optimalLocalSize, 1, 1);
48+
worker.setLocalWork(optimalLocalSize, 1, 1);
49+
return worker;
50+
}
51+
52+
/**
53+
* FFN gate+up worker: combined projection
54+
*/
55+
public static WorkerGrid createGateUpWorker(int hiddenDim) {
56+
int global = (2 * hiddenDim) * DEFAULT_WORK_GROUP_SIZE;
57+
WorkerGrid worker = new WorkerGrid1D(global);
58+
worker.setLocalWork(DEFAULT_WORK_GROUP_SIZE, 1, 1);
59+
return worker;
60+
}
61+
62+
/**
63+
* FFN down worker: final projection
64+
*/
65+
public static WorkerGrid createDownWorker(int dim) {
66+
int global = dim * DEFAULT_WORK_GROUP_SIZE;
67+
WorkerGrid worker = new WorkerGrid1D(global);
68+
worker.setLocalWork(DEFAULT_WORK_GROUP_SIZE, 1, 1);
69+
return worker;
70+
}
71+
72+
private static int findOptimalLocalSize(int size) {
73+
int optimal = Math.min(size, 64);
74+
if (size % optimal != 0) {
75+
for (int s = 64; s >= 1; s--) {
76+
if (size % s == 0) {
77+
optimal = s;
78+
break;
79+
}
80+
}
81+
}
82+
return optimal;
83+
}
84+
85+
// private static WorkerGrid createLogitVocabWorker() {
86+
// // RMSNorm operations
87+
// int vocabSizeRowMajor = config.vocabularySize() * LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS;
88+
// WorkerGrid vocabWorker = new WorkerGrid1D(vocabSizeRowMajor);
89+
// vocabWorker.setLocalWork(LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS, 1, 1);
90+
//
91+
// }
92+
}

0 commit comments

Comments
 (0)