-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathmain.cpp
More file actions
1042 lines (946 loc) · 52 KB
/
main.cpp
File metadata and controls
1042 lines (946 loc) · 52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <iostream>
#include <numeric>
#include <fstream>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <queue>
#define JSON_USE_IMPLICIT_CONVERSIONS 0
#include "json.hpp"
#include "kernels.cuh"
using json = nlohmann::json;
constexpr int MAX_NEW_TOKENS_GENERATED = 20; // TODO: parameterize it with program arguments
constexpr int B_TO_MB = 1024 * 1024;
constexpr int B_TO_GB = 1024 * 1024 * 1024;
constexpr int N_LAYERS = 16; // TODO: hardcoded for llama 3.2 1B, just like any other value for now
constexpr int EMBEDDING_LENGTH = 2048;
constexpr int HIDDEN_DIM = 8192;
constexpr int KV_DIM = 512;
constexpr int HEAD_DIM = 64;
constexpr int NUM_Q_HEADS = 32;
constexpr int NUM_K_HEADS = 8;
constexpr int NUM_V_HEADS = 8;
constexpr int GQA_Q_TO_K_RATIO = 4;
constexpr int GQA_ATTN_SCORES_TO_V_RATIO = 4;
constexpr int VOCAB_SIZE = 128256;
constexpr int END_OF_TEXT_TOKEN_ID = 128001; // <|end_of_text|>
constexpr int EOT_ID_TOKEN_ID = 128009; // <|eot_id|>
constexpr int MAX_SEQ_LEN = 2048; // TODO: make it tunable
constexpr int BATCH_SIZE = 2; // TODO: not even close to being good, it's just here to have batching
constexpr int MAX_PROMPT_LEN = 512; // TODO: arbitrary, tunable
constexpr int MAX_BUFFER_SIZE = std::max(MAX_PROMPT_LEN, BATCH_SIZE);
constexpr int BLOCK_SIZE = 16; // TODO: tunable as well, defined the size of a single page in pagedattn
constexpr int V_OFFSET = BLOCK_SIZE * KV_DIM * sizeof(__nv_bfloat16);
constexpr int BLOCK_BYTES = V_OFFSET * 2; // * 2 because K and V
constexpr size_t KV_CACHE_SIZE_BYTES = 2ULL * 1024 * 1024 * 1024; // TODO: 2GB
constexpr int MAX_BLOCKS_PER_SEQ = MAX_SEQ_LEN / BLOCK_SIZE; // 2048 / 16 = 128
constexpr int NUM_BLOCKS = KV_CACHE_SIZE_BYTES / BLOCK_BYTES; // 2*1024*1024*1024/(16*512*2*2) = 65536
constexpr int MAX_SEQUENCES = BATCH_SIZE;
int checkGPUStatus()
{
int device_count = 0;
cudaGetDeviceCount(&device_count);
if (device_count == 0)
{
std::cerr << "No CUDA devices found\n";
return 1;
}
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
std::cout << "Device: " << prop.name << "\n";
std::cout << "Compute capability: " << prop.major << "." << prop.minor << "\n";
std::cout << "Global memory: " << prop.totalGlobalMem / B_TO_MB << " MB\n";
std::cout << "SM count: " << prop.multiProcessorCount << "\n";
std::cout << "Max threads per block: " << prop.maxThreadsPerBlock << std::endl;
size_t free_mem;
size_t total_mem;
cudaMemGetInfo(&free_mem, &total_mem);
std::cout << "Free memory: " << free_mem / B_TO_GB << "GB, total memory: " << total_mem / B_TO_GB << "GB\n";
return 0;
}
struct Weights
{
__nv_bfloat16 *embed_tokens;
__nv_bfloat16 *input_layernorm[N_LAYERS];
__nv_bfloat16 *mlp_gate_proj[N_LAYERS];
__nv_bfloat16 *mlp_up_proj[N_LAYERS];
__nv_bfloat16 *mlp_down_proj[N_LAYERS];
__nv_bfloat16 *post_attn_layernorms[N_LAYERS];
__nv_bfloat16 *w_k[N_LAYERS];
__nv_bfloat16 *w_o[N_LAYERS];
__nv_bfloat16 *w_q[N_LAYERS];
__nv_bfloat16 *w_v[N_LAYERS];
__nv_bfloat16 *norm;
};
int loadWeights(Weights &weights)
{
if (checkGPUStatus() != 0)
{
return 1;
}
// READ SAFETENSORS
std::ifstream safetensors_file("model.safetensors", std::ios_base::binary); // TODO: use args to provide the path or smth
if (!safetensors_file.is_open())
{
std::cout << "Can't open model.safetensors file\n";
safetensors_file.close();
return 1;
}
// READ SAFETENSORS HEADER SIZE
uint64_t header_size;
safetensors_file.read(reinterpret_cast<char *>(&header_size), 8);
// READ SAFETENSORS HEADER
std::string header;
header.resize(header_size);
safetensors_file.read(header.data(), header_size);
// READ OFFSETS OF EVERY LAYER (TENSOR) TO KNOW WHERE EVERY LAYER STARTS AND ENDS IN THE MEMORY
std::unordered_map<std::string, uint64_t> offsets;
json header_json = json::parse(header);
uint64_t max_offset = 0;
for (auto &[key, value] : header_json.items())
{
if (key == "__metadata__")
{
continue;
}
uint64_t offset_end = value["data_offsets"].at(1).get<uint64_t>();
if (offset_end > max_offset)
{
max_offset = offset_end;
}
offsets[key] = value["data_offsets"].at(0).get<uint64_t>();
}
void *model_weights;
cudaMalloc(&model_weights, max_offset); // max_offset tells where the model weights end in the memory
std::vector<char> model_weights_cpu;
model_weights_cpu.resize(max_offset);
safetensors_file.read(model_weights_cpu.data(), max_offset);
cudaMemcpy(model_weights, model_weights_cpu.data(), max_offset, cudaMemcpyHostToDevice);
safetensors_file.close();
// BASICALLY A HELPER STRUCT TO HAVE AN EASY ACCESS TO ANY MODEL WEIGHTS ON GPU
// TODO: right now I know the model structure since it's always llama 3.2 1B-Instruct, but maybe it would be convenient
// to store dimensions somewhere for even easier access?
weights.embed_tokens = (__nv_bfloat16 *)((char *)model_weights + offsets.at("model.embed_tokens.weight"));
weights.norm = (__nv_bfloat16 *)((char *)model_weights + offsets.at("model.norm.weight"));
for (int i = 0; i < N_LAYERS; ++i)
{
weights.input_layernorm[i] = (__nv_bfloat16 *)((char *)model_weights + offsets.at("model.layers." + std::to_string(i) + ".input_layernorm.weight"));
weights.mlp_down_proj[i] = (__nv_bfloat16 *)((char *)model_weights + offsets.at("model.layers." + std::to_string(i) + ".mlp.down_proj.weight"));
weights.mlp_gate_proj[i] = (__nv_bfloat16 *)((char *)model_weights + offsets.at("model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"));
weights.mlp_up_proj[i] = (__nv_bfloat16 *)((char *)model_weights + offsets.at("model.layers." + std::to_string(i) + ".mlp.up_proj.weight"));
weights.post_attn_layernorms[i] = (__nv_bfloat16 *)((char *)model_weights + offsets.at("model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"));
weights.w_k[i] = (__nv_bfloat16 *)((char *)model_weights + offsets.at("model.layers." + std::to_string(i) + ".self_attn.k_proj.weight"));
weights.w_o[i] = (__nv_bfloat16 *)((char *)model_weights + offsets.at("model.layers." + std::to_string(i) + ".self_attn.o_proj.weight"));
weights.w_q[i] = (__nv_bfloat16 *)((char *)model_weights + offsets.at("model.layers." + std::to_string(i) + ".self_attn.q_proj.weight"));
weights.w_v[i] = (__nv_bfloat16 *)((char *)model_weights + offsets.at("model.layers." + std::to_string(i) + ".self_attn.v_proj.weight"));
}
return 0;
}
// TODO: clean up this mess lol XD (I mean, the arguments list is so long, but maybe that's unavoidable, I don't know yet)
void prefill(std::vector<int> &prompt, std::queue<std::vector<int>> &queue, int &prompt_len, std::vector<bool> &is_slot_free, int slot, int *gpu_input_tokens, nv_bfloat16 *input_embeddings, Weights &weights, nv_bfloat16 *hidden_state, nv_bfloat16 *rms_norms, nv_bfloat16 *&q_proj, nv_bfloat16 *buf_2048_1, cublasHandle_t cublas_handle, float &q_proj_alpha, float &q_proj_beta, float &k_proj_alpha, float &k_proj_beta, float &v_proj_alpha, float &v_proj_beta, nv_bfloat16 *prefill_attn_scores, float &attn_alpha, float &attn_beta, nv_bfloat16 *&attn_scores_v, float &attn_scores_v_alpha, float &attn_scores_v_beta, nv_bfloat16 *&o_proj, nv_bfloat16 *buf_2048_2, float &o_proj_alpha, float &o_proj_beta, float &gate_alpha, float &gate_beta, nv_bfloat16 *gate, float &up_alpha, float &up_beta, nv_bfloat16 *up, nv_bfloat16 *&down, float &down_alpha, float &down_beta, float &embed_alpha, float &embed_beta, nv_bfloat16 *embed_proj, std::vector<nv_bfloat16> &embed_proj_cpu, std::vector<std::vector<int>> &generated_tokens, std::vector<int> &last_generated_tokens, std::vector<int> ¤t_prompt_len, __nv_bfloat16 *k_proj_temp_buf, __nv_bfloat16 *v_proj_temp_buf, std::vector<int> &block_table, int *block_table_gpu, std::vector<int> &free_blocks, __nv_bfloat16 *kv_cache)
{
prompt = queue.front();
prompt_len = prompt.size();
queue.pop();
is_slot_free[slot] = false;
cudaMemcpy(gpu_input_tokens, prompt.data(), prompt_len * sizeof(int), cudaMemcpyHostToDevice);
embeddingGather(gpu_input_tokens, input_embeddings, weights.embed_tokens, prompt_len);
cudaMemcpy(hidden_state,
input_embeddings,
prompt_len * EMBEDDING_LENGTH * sizeof(__nv_bfloat16),
cudaMemcpyDeviceToDevice);
for (int layer = 0; layer < N_LAYERS; ++layer)
{
rmsNorm(hidden_state, rms_norms, weights.input_layernorm[layer], prompt_len);
// Q = inputs * wq^T; my matrices are row-major, cublas expects column-major
// it perceives my matrices as transposed
// there's a trick where C = A * B == C^T = B^T * A^T
// so in my scenario cublas sees now: Q = inputs^T * wq^T^T = inputs ^T * wq
// so I need to do: Q^T = wq ^T * inputs
// the beauty is that we don't need to transpose Q^T back to Q
// because cublas sees the output as column-major
// so it's in fact transposed
// final dim (num_tok, EMBEDDING_LENGTH)
q_proj = buf_2048_1;
cublasStatus_t q_proj_status = cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
EMBEDDING_LENGTH,
prompt_len,
EMBEDDING_LENGTH,
&q_proj_alpha,
weights.w_q[layer],
CUDA_R_16BF,
EMBEDDING_LENGTH,
rms_norms,
CUDA_R_16BF,
EMBEDDING_LENGTH,
&q_proj_beta,
q_proj,
CUDA_R_16BF,
EMBEDDING_LENGTH,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
// input = (num_tokens, EMBEDDING_LENGTH), weights = (KV_DIM, EMBEDDING_LENGTH)
// after trick: (KV_DIM, EMBEDDING_LENGTH) * (EMBEDDING_LENGTH, num_tokens) -> (KV_DIM, num_tokens), which really is (num_tok, KV_DIM)
// lda: EMBEDDING_LENGTH, ldb: EMBEDDING_LENGTH, ldc: KV_DIM
cublasStatus_t k_proj_status = cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
KV_DIM,
prompt_len,
EMBEDDING_LENGTH,
&k_proj_alpha,
weights.w_k[layer],
CUDA_R_16BF,
EMBEDDING_LENGTH,
rms_norms,
CUDA_R_16BF,
EMBEDDING_LENGTH,
&k_proj_beta,
k_proj_temp_buf,
CUDA_R_16BF,
KV_DIM,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
// same as K projection
cublasStatus_t v_proj_status = cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
KV_DIM,
prompt_len,
EMBEDDING_LENGTH,
&v_proj_alpha,
weights.w_v[layer],
CUDA_R_16BF,
EMBEDDING_LENGTH,
rms_norms,
CUDA_R_16BF,
EMBEDDING_LENGTH,
&v_proj_beta,
v_proj_temp_buf,
CUDA_R_16BF,
KV_DIM,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
// RoPE now
rope(q_proj, prompt_len, EMBEDDING_LENGTH);
rope(k_proj_temp_buf, prompt_len, KV_DIM);
// PagedAttention - scatter K and V into blocks
// slot - index within batch
// layer - index of layer
// ceil(prompt_len/BLOCK_SIZE) = number of blocks needed to allocate in block table
for (int token_idx = 0; token_idx < prompt_len; token_idx += BLOCK_SIZE)
{
int num_tokens_to_copy = prompt_len - token_idx;
if (num_tokens_to_copy > BLOCK_SIZE)
{
num_tokens_to_copy = BLOCK_SIZE;
}
// read index of physical block from logical block_table
// if -1, then need to allocate the new block
// pop from free_blocks
// write its value to block_table on the same position we read from
// compute address of this block table in kv_cache
// write tokens to it
int block_idx = token_idx / BLOCK_SIZE;
int block = block_table[slot * N_LAYERS * MAX_BLOCKS_PER_SEQ + layer * MAX_BLOCKS_PER_SEQ + block_idx];
if (block == -1)
{
int physical_block_idx = free_blocks.back();
free_blocks.pop_back();
block = physical_block_idx;
block_table[slot * N_LAYERS * MAX_BLOCKS_PER_SEQ + layer * MAX_BLOCKS_PER_SEQ + block_idx] = block;
}
else
{
assert(false && "block must be -1 during prefill - what happened?");
// probably in prefill this doesn't make a lot of sense? but will matter in decode
}
// store K
__nv_bfloat16 *k_cache_ptr = (__nv_bfloat16 *)((char *)kv_cache + block * BLOCK_BYTES);
__nv_bfloat16 *k_proj_ptr = k_proj_temp_buf + token_idx * KV_DIM;
cudaMemcpy(k_cache_ptr, k_proj_ptr, num_tokens_to_copy * KV_DIM * sizeof(__nv_bfloat16), cudaMemcpyDeviceToDevice);
// store V
__nv_bfloat16 *v_cache_ptr = (__nv_bfloat16 *)((char *)kv_cache + block * BLOCK_BYTES + V_OFFSET);
__nv_bfloat16 *v_proj_ptr = v_proj_temp_buf + token_idx * KV_DIM;
cudaMemcpy(v_cache_ptr, v_proj_ptr, num_tokens_to_copy * KV_DIM * sizeof(__nv_bfloat16), cudaMemcpyDeviceToDevice);
}
// attention scores
// per head, 64 elements each
// so total 32 heads
// Q (num_tok, 2048)
// K (num_tok, 512)
// GQA grouping reuses 1 K head per 4 consecutive Q heads
// Q_head (num_tok, 64)
// K_head (num_tok, 64)
// attn_score_head = Q_head * K_head^T / sqrt(64)
// so: head output dims (num_tok, num_tok)
// total output (32, num_tok, num_tok)
for (int i = 0; i < NUM_Q_HEADS; ++i)
{
int k_head_idx = i / GQA_Q_TO_K_RATIO;
__nv_bfloat16 *q_head = q_proj + i * HEAD_DIM;
__nv_bfloat16 *k_head = k_proj_temp_buf + k_head_idx * HEAD_DIM;
__nv_bfloat16 *attn_score_head = prefill_attn_scores + prompt_len * prompt_len * i;
cublasStatus_t attn_score_status = cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
prompt_len,
prompt_len,
HEAD_DIM,
&attn_alpha,
k_head,
CUDA_R_16BF,
KV_DIM,
q_head,
CUDA_R_16BF,
EMBEDDING_LENGTH,
&attn_beta,
attn_score_head,
CUDA_R_16BF,
prompt_len,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
}
causalMask(prefill_attn_scores, prompt_len);
softmax(prefill_attn_scores, prompt_len);
// attn scores * V
// (32, num_tok, num_tok) * (num_tok, 512)
// GQA - 4 Q heads share 1 V head
// attn_scores dim (32, num_tok, num_tok)
// attn_scores head dim (num_tok, num_tok)
// V dim (num_tok, 512)
// NUM_V_HEADS is 8 -> 512 / 8 = 64
// V_head dim (num_tok, 64)
// output head dim: scores head * V head -> (num_tok, num_tok) * (num_tok, 64) = (num_tok, 64)
// in total 32 output heads: so (num_tok, 64 * 32) = (num_tok, 2048)
attn_scores_v = buf_2048_1;
for (int i = 0; i < NUM_Q_HEADS; ++i)
{
int v_head_idx = i / GQA_ATTN_SCORES_TO_V_RATIO;
// i * prompt_under_prefill.size() * prompt_under_prefill.size(), because attn scores is (32, num_tok, num_tok)
__nv_bfloat16 *attn_scores_head = prefill_attn_scores + i * prompt_len * prompt_len;
__nv_bfloat16 *v_head = v_proj_temp_buf + v_head_idx * HEAD_DIM;
__nv_bfloat16 *output_attn_scores_head = attn_scores_v + i * HEAD_DIM;
cublasStatus_t attn_score_status = cublasGemmEx(cublas_handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
HEAD_DIM,
prompt_len,
prompt_len,
&attn_scores_v_alpha,
v_head,
CUDA_R_16BF,
KV_DIM,
attn_scores_head,
CUDA_R_16BF,
prompt_len,
&attn_scores_v_beta,
output_attn_scores_head,
CUDA_R_16BF,
EMBEDDING_LENGTH,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
}
// output projection, it will be an input for MLP blocks
// attn_scores_v * w_o^T
// (num_tok, 2048) * (2048, 2048) -> (num_tok, 2048)
// same as Q projection, so copy paste
o_proj = buf_2048_2;
cublasStatus_t o_proj_status = cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
EMBEDDING_LENGTH,
prompt_len,
EMBEDDING_LENGTH,
&o_proj_alpha,
weights.w_o[layer],
CUDA_R_16BF,
EMBEDDING_LENGTH,
attn_scores_v,
CUDA_R_16BF,
EMBEDDING_LENGTH,
&o_proj_beta,
o_proj,
CUDA_R_16BF,
EMBEDDING_LENGTH,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
// (num_tok, 2048) + (num_tok, 2048) -> (num_tok, 2048)
residualAdd(hidden_state, o_proj, prompt_len);
// post attention RMS Norm
rmsNorm(hidden_state, rms_norms, weights.post_attn_layernorms[layer], prompt_len);
// SwiGLU time - just MLP + SiLU
// gate = hidden_state (rms-normed) * mlp_gate_proj ^ T
// HIDDEN_DIM = 8192
// (num_tok, 2048) * (2048, 8192) -> (num_tok, 8192)
// my data is row major so transpose trick
// gate ^T = (mlp_gate_proj ^ T)^T * hidden_state^T
// gate ^T = mlp_gate_proj * hidden_state^T
// (num_tok, 8192)^T = (8192, 2048) * (2048, num_tok)
// but data is perceived as column major so I need to transpose mlp_gate_proj
// to make it work
// m 8192 n num_tok k 2048 lda 2048 ldb 2048 ldc 8192
cublasStatus_t gate_status = cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
HIDDEN_DIM,
prompt_len,
EMBEDDING_LENGTH,
&gate_alpha,
weights.mlp_gate_proj[layer],
CUDA_R_16BF,
EMBEDDING_LENGTH,
rms_norms,
CUDA_R_16BF,
EMBEDDING_LENGTH,
&gate_beta,
gate,
CUDA_R_16BF,
HIDDEN_DIM,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
// up, the same dims as gate
cublasStatus_t up_status = cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
HIDDEN_DIM,
prompt_len,
EMBEDDING_LENGTH,
&up_alpha,
weights.mlp_up_proj[layer],
CUDA_R_16BF,
EMBEDDING_LENGTH,
rms_norms,
CUDA_R_16BF,
EMBEDDING_LENGTH,
&up_beta,
up,
CUDA_R_16BF,
HIDDEN_DIM,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
// SiLU
// after_silu = SiLU(gate) * up (element-wise multication)
// after_silu = gate * (1 / (1 + e^(-gate))) * up
// gate is dim (num_tok, 8192), up too
silu(gate, up, prompt_len); // gate = after_silu now
// down projection
// output = post-silu * down_proj^T
// dims: (num_tok, 8192) * (2048, 8192) ^ T = (num_tok, 8192) * (8192, 2048) = (num_tok, 2048)
// output^T = (down_proj^T)^T * post-silu^T
// output^T = down_proj * post-silu^T
// cublas sees them already as transposed so only down_proj I need to transpose
// dims = (2048, 8192) * (8192, num_tok) = (2048, num_tok)
// m: 2048 n: num_tok, k: 8192
// lda: 8192, ldb: 8192, ldc: 2048
down = buf_2048_2;
cublasStatus_t down_status = cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
EMBEDDING_LENGTH,
prompt_len,
HIDDEN_DIM,
&down_alpha,
weights.mlp_down_proj[layer],
CUDA_R_16BF,
HIDDEN_DIM,
gate,
CUDA_R_16BF,
HIDDEN_DIM,
&down_beta,
down,
CUDA_R_16BF,
EMBEDDING_LENGTH,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
// (num_tok, 2048) + (num_tok, 2048) -> (num_tok, 2048)
residualAdd(hidden_state, down, prompt_len);
}
rmsNorm(hidden_state, rms_norms, weights.norm, prompt_len);
// logits = rms_norms * weights.embed_tokens^T
// dim rms_norms: (num_tok, 2048), dim embed_tokens: (128256, 2048)
// logits dim = (num_tok, 2048) * (2048, 128256) = (num_tok, 128256) => m = num_tok, n = 128256, k = 2048
// I leave this comment above because it shows a bug in my thinking
// because I use the cublas trick, logits are transposed so m and n should be swapped
// so m 128256, n num_tok
// data is row major so we treat it as transposed and use the trick
// logits^T = ((weights.embed_tokens^T)^T * rms_norms^T
// logits^T = weights.embed_tokens * rms_norms^T
// so we need to transpose embed_tokens, because rms_norms already
// appears to cublas as transposed
// lda = 2048, ldb = 2048, ldc = 128256
cublasStatus_t embed_status = cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
VOCAB_SIZE,
prompt_len,
EMBEDDING_LENGTH,
&embed_alpha,
weights.embed_tokens,
CUDA_R_16BF,
EMBEDDING_LENGTH,
rms_norms,
CUDA_R_16BF,
EMBEDDING_LENGTH,
&embed_beta,
embed_proj,
CUDA_R_16BF,
VOCAB_SIZE,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
cudaMemcpy(embed_proj_cpu.data(), embed_proj, sizeof(__nv_bfloat16) * prompt_len * VOCAB_SIZE, cudaMemcpyDeviceToHost);
// argmax to get the output token
// TODO: write a proper kernel for it
// for now just a simple CPU function
int last_token_offset = (prompt_len - 1) * VOCAB_SIZE;
float max_token = (float)embed_proj_cpu[last_token_offset];
int max_token_idx = 0;
for (int token_idx = 0; token_idx < VOCAB_SIZE; ++token_idx)
{
if ((float)embed_proj_cpu[token_idx + last_token_offset] > max_token)
{
max_token = embed_proj_cpu[token_idx + last_token_offset];
max_token_idx = token_idx;
}
}
std::cout << "Output token: " << (float)max_token << ", token index: " << std::to_string(max_token_idx) << std::endl;
generated_tokens[slot].push_back(max_token_idx);
last_generated_tokens[slot] = max_token_idx;
current_prompt_len[slot] = prompt_len;
// synchronize state of block_table with block_table_gpu
// TODO: do it more clever and not copy full table unnecessarily
cudaMemcpy(block_table_gpu, block_table.data(), MAX_SEQUENCES * N_LAYERS * MAX_BLOCKS_PER_SEQ * sizeof(int), cudaMemcpyHostToDevice);
}
int main(int argc, char *argv[])
{
cublasHandle_t cublas_handle;
cublasStatus_t status = cublasCreate(&cublas_handle);
if (status != CUBLAS_STATUS_SUCCESS)
{
std::cerr << "cuBLAS init failed, status: " << status << "\n";
return 1;
}
Weights weights{};
if (loadWeights(weights) != 0)
{
return 1;
}
// allocator for pagedattn
__nv_bfloat16 *kv_cache;
cudaMalloc(&kv_cache, KV_CACHE_SIZE_BYTES);
std::vector<int> free_blocks(NUM_BLOCKS);
std::iota(free_blocks.begin(), free_blocks.end(), 0);
std::vector<int> block_table(MAX_SEQUENCES * N_LAYERS * MAX_BLOCKS_PER_SEQ, -1);
int *block_table_gpu;
cudaMalloc(&block_table_gpu, MAX_SEQUENCES * N_LAYERS * MAX_BLOCKS_PER_SEQ * sizeof(int));
// PROMPT 0 (What is 2+2?) - length 17
std::queue<std::vector<int>> queue;
queue.push({128000, 128006, 882, 128007, 271, 3923, 374, 220, 17, 10, 17, 30, 128009, 128006, 78191, 128007, 271});
// PROMPT 1 (Name a color.) - length 14
queue.push({128000, 128006, 882, 128007, 271, 678, 264, 1933, 13, 128009, 128006, 78191, 128007, 271});
// PROMPT 2 (Say hello.) - length 13
queue.push({128000, 128006, 882, 128007, 271, 46864, 24748, 13, 128009, 128006, 78191, 128007, 271});
// PROMPT 3 (Capital of France?) - length 14
queue.push({128000, 128006, 882, 128007, 271, 64693, 315, 9822, 30, 128009, 128006, 78191, 128007, 271});
// BATCH
std::vector<bool> is_slot_free(BATCH_SIZE, true); // set to false when slot taken, set to true when free
std::vector<std::vector<int>> generated_tokens(BATCH_SIZE);
std::vector<int> last_generated_tokens(BATCH_SIZE);
std::vector<int> current_prompt_len(BATCH_SIZE, 0);
// needed to provide contiguous data for decode
std::vector<int> active_slots;
std::vector<int> active_tokens;
int *gpu_active_slots;
cudaMalloc(&gpu_active_slots, BATCH_SIZE * sizeof(int));
int *gpu_seq_lens;
cudaMalloc(&gpu_seq_lens, BATCH_SIZE * sizeof(int));
// TODO: recalculate input_tokens_size and prompt_lengths always when there is a change to prompt_under_prefill
// TODO: right now I handle input manually, it's the least interesting part, will come back to it when continuous batching and pagedattn works
std::vector<int> prompt;
int prompt_len;
int *gpu_input_tokens;
cudaMalloc(&gpu_input_tokens, MAX_PROMPT_LEN * sizeof(int));
__nv_bfloat16 *input_embeddings;
cudaMalloc(&input_embeddings, MAX_PROMPT_LEN * sizeof(__nv_bfloat16) * EMBEDDING_LENGTH);
__nv_bfloat16 *hidden_state;
cudaMalloc(&hidden_state, MAX_BUFFER_SIZE * sizeof(__nv_bfloat16) * EMBEDDING_LENGTH);
__nv_bfloat16 *rms_norms;
cudaMalloc(&rms_norms, MAX_BUFFER_SIZE * sizeof(__nv_bfloat16) * EMBEDDING_LENGTH);
__nv_bfloat16 *buf_2048_1; // shared between q_proj and attn_scores_v
cudaMalloc(&buf_2048_1, MAX_BUFFER_SIZE * sizeof(__nv_bfloat16) * EMBEDDING_LENGTH);
__nv_bfloat16 *q_proj;
float q_proj_alpha = 1.0f;
float q_proj_beta = 0.0f;
// K and V cache
__nv_bfloat16 *k_proj_temp_buf;
cudaMalloc(&k_proj_temp_buf, MAX_PROMPT_LEN * KV_DIM * sizeof(__nv_bfloat16));
__nv_bfloat16 *v_proj_temp_buf;
cudaMalloc(&v_proj_temp_buf, MAX_PROMPT_LEN * KV_DIM * sizeof(__nv_bfloat16));
float k_proj_alpha = 1.0f;
float k_proj_beta = 0.0f;
float v_proj_alpha = 1.0f;
float v_proj_beta = 0.0f;
__nv_bfloat16 *prefill_attn_scores;
cudaMalloc(&prefill_attn_scores, MAX_PROMPT_LEN * MAX_PROMPT_LEN * sizeof(__nv_bfloat16) * NUM_Q_HEADS);
float attn_alpha = 1.0f / 8.0f;
float attn_beta = 0.0f;
__nv_bfloat16 *attn_scores_v;
float attn_scores_v_alpha = 1.0f;
float attn_scores_v_beta = 0.0f;
__nv_bfloat16 *buf_2048_2; // shared between o_proj and down
cudaMalloc(&buf_2048_2, MAX_BUFFER_SIZE * sizeof(__nv_bfloat16) * EMBEDDING_LENGTH);
__nv_bfloat16 *o_proj;
float o_proj_alpha = 1.0f;
float o_proj_beta = 0.0f;
__nv_bfloat16 *gate;
cudaMalloc(&gate, MAX_BUFFER_SIZE * sizeof(__nv_bfloat16) * HIDDEN_DIM);
float gate_alpha = 1.0f;
float gate_beta = 0.0f;
__nv_bfloat16 *up;
cudaMalloc(&up, MAX_BUFFER_SIZE * sizeof(__nv_bfloat16) * HIDDEN_DIM);
float up_alpha = 1.0f;
float up_beta = 0.0f;
__nv_bfloat16 *down;
float down_alpha = 1.0f;
float down_beta = 0.0f;
__nv_bfloat16 *embed_proj;
cudaMalloc(&embed_proj, sizeof(__nv_bfloat16) * MAX_BUFFER_SIZE * VOCAB_SIZE);
float embed_alpha = 1.0f;
float embed_beta = 0.0f;
std::vector<__nv_bfloat16> embed_proj_cpu;
embed_proj_cpu.resize(MAX_BUFFER_SIZE * VOCAB_SIZE);
// decode-only allocation
int *gpu_last_tokens;
cudaMalloc(&gpu_last_tokens, BATCH_SIZE * sizeof(int));
// TODO: move argmax to GPU and get rid of these CPU<->GPU tokens moves
// reused temporary buffers for K and V cache computation during decode
__nv_bfloat16 *k_proj_batched_buffer;
cudaMalloc(&k_proj_batched_buffer, BATCH_SIZE * sizeof(__nv_bfloat16) * KV_DIM);
__nv_bfloat16 *v_proj_batched_buffer;
cudaMalloc(&v_proj_batched_buffer, BATCH_SIZE * sizeof(__nv_bfloat16) * KV_DIM);
for (int slot = 0; slot < is_slot_free.size() && !queue.empty(); ++slot)
{
if (!is_slot_free[slot])
{
continue; // slot taken, skip
}
prefill(prompt, queue, prompt_len, is_slot_free, slot, gpu_input_tokens, input_embeddings, weights, hidden_state, rms_norms, q_proj, buf_2048_1, cublas_handle, q_proj_alpha, q_proj_beta, k_proj_alpha, k_proj_beta, v_proj_alpha, v_proj_beta, prefill_attn_scores, attn_alpha, attn_beta, attn_scores_v, attn_scores_v_alpha, attn_scores_v_beta, o_proj, buf_2048_2, o_proj_alpha, o_proj_beta, gate_alpha, gate_beta, gate, up_alpha, up_beta, up, down, down_alpha, down_beta, embed_alpha, embed_beta, embed_proj, embed_proj_cpu, generated_tokens, last_generated_tokens, current_prompt_len, k_proj_temp_buf, v_proj_temp_buf, block_table, block_table_gpu, free_blocks, kv_cache);
// // after prefill:
// int first_token = -1; // TODO just a stub
// last_generated_tokens[slot] = first_token;
// current_prompt_len[slot] = prompt.size();
// generated_tokens[slot].push_back(first_token);
}
// INFERENCE STARTS HERE! =]
// I have the same amount of embeddings as input tokens
// it's just every embedding is EMBEDDING_LENGTH length bf16 vector
// retrieved from model weights based on token's value
// PREFILL
// DECODE
// since now I operate always on index 0 for all values and for current_position_token for new K and V
while (true) // exit condition irrelevant for now, since it's an inference server that's supposed to run foreveeer!!!
{
active_slots.clear();
active_tokens.clear();
for (int slot = 0; slot < BATCH_SIZE; ++slot)
{
if (is_slot_free[slot])
{
if (queue.empty())
{
continue;
}
generated_tokens[slot].clear();
prefill(prompt, queue, prompt_len, is_slot_free, slot, gpu_input_tokens, input_embeddings, weights, hidden_state, rms_norms, q_proj, buf_2048_1, cublas_handle, q_proj_alpha, q_proj_beta, k_proj_alpha, k_proj_beta, v_proj_alpha, v_proj_beta, prefill_attn_scores, attn_alpha, attn_beta, attn_scores_v, attn_scores_v_alpha, attn_scores_v_beta, o_proj, buf_2048_2, o_proj_alpha, o_proj_beta, gate_alpha, gate_beta, gate, up_alpha, up_beta, up, down, down_alpha, down_beta, embed_alpha, embed_beta, embed_proj, embed_proj_cpu, generated_tokens, last_generated_tokens, current_prompt_len, k_proj_temp_buf, v_proj_temp_buf, block_table, block_table_gpu, free_blocks, kv_cache);
}
active_slots.push_back(slot);
active_tokens.push_back(last_generated_tokens[slot]);
}
int num_active_slots = active_slots.size();
if (num_active_slots == 0)
{
if (queue.empty())
{
break; // TODO: continue will make sense when I will finally write to queue, for now it has predefined size so break instead
}
continue;
}
// copy useful data to gpu
cudaMemcpy(gpu_last_tokens, active_tokens.data(), num_active_slots * sizeof(int), cudaMemcpyHostToDevice);
cudaMemcpy(gpu_active_slots, active_slots.data(), num_active_slots * sizeof(int), cudaMemcpyHostToDevice);
std::vector<int> seq_lens(num_active_slots);
for (int slot = 0; slot < num_active_slots; ++slot)
{
int active_slot = active_slots[slot];
seq_lens[slot] = current_prompt_len[active_slot] + 1;
}
cudaMemcpy(gpu_seq_lens, seq_lens.data(), seq_lens.size() * sizeof(int), cudaMemcpyHostToDevice);
embeddingGatherDecode(gpu_last_tokens, num_active_slots, hidden_state, weights.embed_tokens);
for (int layer = 0; layer < N_LAYERS; ++layer)
{
rmsNorm(hidden_state, rms_norms, weights.input_layernorm[layer], num_active_slots);
q_proj = buf_2048_1;
// q proj (num_prompts, 2048)
cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
EMBEDDING_LENGTH, // m
num_active_slots, // n
EMBEDDING_LENGTH, // k
&q_proj_alpha,
weights.w_q[layer], // A
CUDA_R_16BF,
EMBEDDING_LENGTH, // lda
rms_norms, // B
CUDA_R_16BF,
EMBEDDING_LENGTH, // ldb
&q_proj_beta,
q_proj, // C
CUDA_R_16BF,
EMBEDDING_LENGTH, // ldc
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
// k proj (1, 512), writing output to next position in current layer's K cache
// K proj = rms_norms (num_prompt, 2048) * W_k (512, 2048)
// W_k is actually stored as 512, 2048 (out features, in features)
// so that's why we need to transpose it
// all the data is stored in row major and cublas reads it as column major
// so all the data appears as transposed
// so data actually apppears as (2048, num_prompt) * (2048, 512)
// the output of matmul will also be produced as transposed, so we can say that
// in our mental model we talk about K_proj^T
// and to get K_proj^T we can do transposition trick and write the cublas call as
// W_k^T * rms_nroms
// so we end up with: K_proj^T = W_k^T (512, 2048) * rms_norms (2048, num_prompt)
// result dim is K_proj^T = (512, num_prompt)
// but it's transposed, so in fact we get correct output dimension (num_prompt, 512)
// it was great for num_prompt=1, but the problem is that prompts have different length
// that's why we have vector of current_prompt_len, but also we can't write to K_proj
// directly, so I write to temp buffer kv_proj_batched_buffer and the scatter
// output to K_proj in a loop
cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
KV_DIM, // m = 512
num_active_slots, // n = num prompts
EMBEDDING_LENGTH, // k = 2048
&k_proj_alpha,
weights.w_k[layer], // A
CUDA_R_16BF,
EMBEDDING_LENGTH, // lda 2048, because W_k is in memory as 512, 2048
// so the gap between subsequent elements is 2048
rms_norms, // B
CUDA_R_16BF,
EMBEDDING_LENGTH, // ldb, same reason for rms_norms
&k_proj_beta,
k_proj_batched_buffer, // TODO C
CUDA_R_16BF,
KV_DIM, // ldc = 512
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
// same
cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
KV_DIM,
num_active_slots,
EMBEDDING_LENGTH,
&v_proj_alpha,
weights.w_v[layer],
CUDA_R_16BF,
EMBEDDING_LENGTH,
rms_norms,
CUDA_R_16BF,
EMBEDDING_LENGTH,
&v_proj_beta,
v_proj_batched_buffer,
CUDA_R_16BF,
KV_DIM,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
for (int slot = 0; slot < num_active_slots; ++slot)
{
int active_slot = active_slots[slot];
ropeDecode(&q_proj[slot * EMBEDDING_LENGTH], current_prompt_len[active_slot], EMBEDDING_LENGTH);
ropeDecode(k_proj_batched_buffer + slot * KV_DIM, current_prompt_len[active_slot], KV_DIM);
}
// PagedAttn scatter k and v from a temp buffer, like in the prefill
for (int slot = 0; slot < num_active_slots; ++slot)
{
int active_slot = active_slots[slot];
int seq_len = current_prompt_len[active_slot]; // + generated tokens?
int logical_block_idx = seq_len / BLOCK_SIZE;
int token_in_block_idx = seq_len % BLOCK_SIZE;
int block = block_table[active_slot * N_LAYERS * MAX_BLOCKS_PER_SEQ + layer * MAX_BLOCKS_PER_SEQ + logical_block_idx];
if (token_in_block_idx == 0)
{
int physical_block_idx = free_blocks.back();
free_blocks.pop_back();
block = physical_block_idx;
block_table[active_slot * N_LAYERS * MAX_BLOCKS_PER_SEQ + layer * MAX_BLOCKS_PER_SEQ + logical_block_idx] = block;
}
__nv_bfloat16 *k_cache_ptr = (__nv_bfloat16 *)((char *)kv_cache + block * BLOCK_BYTES + token_in_block_idx * KV_DIM * sizeof(__nv_bfloat16));
__nv_bfloat16 *k_proj_ptr = k_proj_batched_buffer + slot * KV_DIM;
cudaMemcpy(k_cache_ptr, k_proj_ptr, KV_DIM * sizeof(__nv_bfloat16), cudaMemcpyDeviceToDevice);
__nv_bfloat16 *v_cache_ptr = (__nv_bfloat16 *)((char *)kv_cache + block * BLOCK_BYTES + V_OFFSET + token_in_block_idx * KV_DIM * sizeof(__nv_bfloat16));
__nv_bfloat16 *v_proj_ptr = v_proj_batched_buffer + slot * KV_DIM;
cudaMemcpy(v_cache_ptr, v_proj_ptr, KV_DIM * sizeof(__nv_bfloat16), cudaMemcpyDeviceToDevice);
}
// synchronize block table on cpu with block table on gpu (for attention)
cudaMemcpy(block_table_gpu, block_table.data(), MAX_SEQUENCES * N_LAYERS * MAX_BLOCKS_PER_SEQ * sizeof(int), cudaMemcpyHostToDevice);
pagedAttention(layer, num_active_slots, q_proj, kv_cache, block_table_gpu, gpu_seq_lens, gpu_active_slots, buf_2048_1);
o_proj = buf_2048_2;
// (1, 2048) * (2048, 2048) -> (1, 2048)
cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
EMBEDDING_LENGTH, // m
num_active_slots, // n
EMBEDDING_LENGTH, // k
&o_proj_alpha,
weights.w_o[layer],
CUDA_R_16BF,
EMBEDDING_LENGTH,
buf_2048_1,
CUDA_R_16BF,
EMBEDDING_LENGTH,
&o_proj_beta,
o_proj,
CUDA_R_16BF,
EMBEDDING_LENGTH,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
residualAdd(hidden_state, o_proj, num_active_slots);
rmsNorm(hidden_state, rms_norms, weights.post_attn_layernorms[layer], num_active_slots);
// MLP
cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
HIDDEN_DIM, // m
num_active_slots, // n
EMBEDDING_LENGTH, // k
&gate_alpha,
weights.mlp_gate_proj[layer],
CUDA_R_16BF,
EMBEDDING_LENGTH,
rms_norms,
CUDA_R_16BF,
EMBEDDING_LENGTH,
&gate_beta,
gate,
CUDA_R_16BF,
HIDDEN_DIM,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
// (1, 2048) * (2048, 8192) -> (1, 8192)
cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
HIDDEN_DIM, // m
num_active_slots, // n
EMBEDDING_LENGTH, // k
&up_alpha,
weights.mlp_up_proj[layer],
CUDA_R_16BF,
EMBEDDING_LENGTH,
rms_norms,
CUDA_R_16BF,
EMBEDDING_LENGTH,
&up_beta,
up,
CUDA_R_16BF,
HIDDEN_DIM,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
silu(gate, up, num_active_slots);
down = buf_2048_2;
cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
EMBEDDING_LENGTH, // m
num_active_slots, // n
HIDDEN_DIM, // k
&down_alpha,
weights.mlp_down_proj[layer],
CUDA_R_16BF,
HIDDEN_DIM,
gate,
CUDA_R_16BF,
HIDDEN_DIM,
&down_beta,
down,
CUDA_R_16BF,
EMBEDDING_LENGTH,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
residualAdd(hidden_state, down, num_active_slots);
}
rmsNorm(hidden_state, rms_norms, weights.norm, num_active_slots);
cublasGemmEx(cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
VOCAB_SIZE, // m
num_active_slots, // n
EMBEDDING_LENGTH, // k
&embed_alpha,
weights.embed_tokens,
CUDA_R_16BF,
EMBEDDING_LENGTH,
rms_norms,
CUDA_R_16BF,
EMBEDDING_LENGTH,
&embed_beta,
embed_proj,
CUDA_R_16BF,
VOCAB_SIZE,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);
cudaMemcpy(embed_proj_cpu.data(), embed_proj, sizeof(__nv_bfloat16) * num_active_slots * VOCAB_SIZE, cudaMemcpyDeviceToHost);
float max_token = 0.0;
int max_token_idx = 0;
for (int slot = 0; slot < num_active_slots; ++slot)
{
int active_slot = active_slots[slot];