@@ -9953,13 +9953,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32(
99539953 const int ith = params->ith ;
99549954 const int nth = params->nth ;
99559955
9956- if (ith >= HEADS) {
9957- return ;
9958- }
9959-
9960- const int h_start = (HEADS * ith) / nth;
9961- const int h_end = ((HEADS * (ith + 1 )) / nth < HEADS) ?
9962- (HEADS * (ith + 1 )) / nth : HEADS;
9956+ const int h_start = (HEADS * (ith )) / nth;
9957+ const int h_end = ((HEADS * (ith + 1 )) / nth < HEADS) ?
9958+ (HEADS * (ith + 1 )) / nth : HEADS;
99639959
99649960 float * k = (float *) dst->src [0 ]->data ;
99659961 float * v = (float *) dst->src [1 ]->data ;
@@ -10170,13 +10166,9 @@ static void ggml_compute_forward_gla_f32(
1017010166 const int ith = params->ith ;
1017110167 const int nth = params->nth ;
1017210168
10173- if (ith >= HEADS) {
10174- return ;
10175- }
10176-
10177- const int h_start = (HEADS * ith) / nth;
10178- const int h_end = ((HEADS * (ith + 1 )) / nth < HEADS) ?
10179- (HEADS * (ith + 1 )) / nth : HEADS;
10169+ const int h_start = (HEADS * (ith )) / nth;
10170+ const int h_end = ((HEADS * (ith + 1 )) / nth < HEADS) ?
10171+ (HEADS * (ith + 1 )) / nth : HEADS;
1018010172
1018110173 float * k = (float *) dst->src [0 ]->data ;
1018210174 float * v = (float *) dst->src [1 ]->data ;
@@ -10633,13 +10625,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32(
1063310625 const int ith = params->ith ;
1063410626 const int nth = params->nth ;
1063510627
10636- if (ith >= HEADS) {
10637- return ;
10638- }
10639-
10640- const int h_start = (HEADS * ith) / nth;
10641- const int h_end = ((HEADS * (ith + 1 )) / nth < HEADS) ?
10642- (HEADS * (ith + 1 )) / nth : HEADS;
10628+ const int h_start = (HEADS * (ith )) / nth;
10629+ const int h_end = ((HEADS * (ith + 1 )) / nth < HEADS) ?
10630+ (HEADS * (ith + 1 )) / nth : HEADS;
1064310631
1064410632 float * r = (float *) dst->src [0 ]->data ;
1064510633 float * w = (float *) dst->src [1 ]->data ;
0 commit comments