|
12 | 12 | #include <cfloat> |
13 | 13 | #include <cmath> |
14 | 14 |
|
15 | | -extern "C" GGML_API int turbo3_cpu_wht_group_size; |
16 | | - |
17 | 15 | // ggml_compute_forward_dup |
18 | 16 |
|
19 | 17 | static void ggml_compute_forward_dup_same_cont( |
@@ -682,6 +680,7 @@ void ggml_compute_forward_add( |
682 | 680 | case GGML_TYPE_TQ2_0: |
683 | 681 | case GGML_TYPE_TQ3_1S: |
684 | 682 | case GGML_TYPE_TQ4_1S: |
| 683 | + case GGML_TYPE_TQ4_0: |
685 | 684 | case GGML_TYPE_IQ2_XXS: |
686 | 685 | case GGML_TYPE_IQ2_XS: |
687 | 686 | case GGML_TYPE_IQ3_XXS: |
@@ -1134,6 +1133,7 @@ void ggml_compute_forward_add1( |
1134 | 1133 | case GGML_TYPE_TQ2_0: |
1135 | 1134 | case GGML_TYPE_TQ3_1S: |
1136 | 1135 | case GGML_TYPE_TQ4_1S: |
| 1136 | + case GGML_TYPE_TQ4_0: |
1137 | 1137 | case GGML_TYPE_IQ2_XXS: |
1138 | 1138 | case GGML_TYPE_IQ2_XS: |
1139 | 1139 | case GGML_TYPE_IQ3_XXS: |
@@ -1265,6 +1265,7 @@ void ggml_compute_forward_acc( |
1265 | 1265 | case GGML_TYPE_TQ2_0: |
1266 | 1266 | case GGML_TYPE_TQ3_1S: |
1267 | 1267 | case GGML_TYPE_TQ4_1S: |
| 1268 | + case GGML_TYPE_TQ4_0: |
1268 | 1269 | case GGML_TYPE_IQ2_XXS: |
1269 | 1270 | case GGML_TYPE_IQ2_XS: |
1270 | 1271 | case GGML_TYPE_IQ3_XXS: |
@@ -4355,6 +4356,7 @@ void ggml_compute_forward_out_prod( |
4355 | 4356 | case GGML_TYPE_TQ2_0: |
4356 | 4357 | case GGML_TYPE_TQ3_1S: |
4357 | 4358 | case GGML_TYPE_TQ4_1S: |
| 4359 | + case GGML_TYPE_TQ4_0: |
4358 | 4360 | case GGML_TYPE_IQ2_XXS: |
4359 | 4361 | case GGML_TYPE_IQ2_XS: |
4360 | 4362 | case GGML_TYPE_IQ3_XXS: |
@@ -4633,6 +4635,7 @@ void ggml_compute_forward_set( |
4633 | 4635 | case GGML_TYPE_TQ2_0: |
4634 | 4636 | case GGML_TYPE_TQ3_1S: |
4635 | 4637 | case GGML_TYPE_TQ4_1S: |
| 4638 | + case GGML_TYPE_TQ4_0: |
4636 | 4639 | case GGML_TYPE_IQ2_XXS: |
4637 | 4640 | case GGML_TYPE_IQ2_XS: |
4638 | 4641 | case GGML_TYPE_IQ3_XXS: |
@@ -4858,6 +4861,7 @@ void ggml_compute_forward_get_rows( |
4858 | 4861 | case GGML_TYPE_TQ2_0: |
4859 | 4862 | case GGML_TYPE_TQ3_1S: |
4860 | 4863 | case GGML_TYPE_TQ4_1S: |
| 4864 | + case GGML_TYPE_TQ4_0: |
4861 | 4865 | case GGML_TYPE_IQ2_XXS: |
4862 | 4866 | case GGML_TYPE_IQ2_XS: |
4863 | 4867 | case GGML_TYPE_IQ3_XXS: |
@@ -4942,6 +4946,7 @@ static void ggml_compute_forward_set_rows_f32( |
4942 | 4946 |
|
4943 | 4947 | // For turbo types: communicate WHT group size to the quantize function via global |
4944 | 4948 | if (dst->type == GGML_TYPE_TURBO3_0 || dst->type == GGML_TYPE_TURBO4_0 || dst->type == GGML_TYPE_TURBO2_0) { |
| 4949 | + extern int turbo3_cpu_wht_group_size; |
4945 | 4950 | int gs = 0; |
4946 | 4951 | memcpy(&gs, dst->op_params, sizeof(int)); |
4947 | 4952 | turbo3_cpu_wht_group_size = (gs == 64 || gs == 128) ? gs : 0; |
@@ -5592,6 +5597,7 @@ void ggml_compute_forward_clamp( |
5592 | 5597 | case GGML_TYPE_TQ2_0: |
5593 | 5598 | case GGML_TYPE_TQ3_1S: |
5594 | 5599 | case GGML_TYPE_TQ4_1S: |
| 5600 | + case GGML_TYPE_TQ4_0: |
5595 | 5601 | case GGML_TYPE_IQ2_XXS: |
5596 | 5602 | case GGML_TYPE_IQ2_XS: |
5597 | 5603 | case GGML_TYPE_IQ3_XXS: |
@@ -9976,9 +9982,13 @@ static void ggml_compute_forward_rwkv_wkv6_f32( |
9976 | 9982 | const int ith = params->ith; |
9977 | 9983 | const int nth = params->nth; |
9978 | 9984 |
|
9979 | | - const int h_start = (HEADS * (ith )) / nth; |
9980 | | - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? |
9981 | | - (HEADS * (ith + 1)) / nth : HEADS; |
| 9985 | + if (ith >= HEADS) { |
| 9986 | + return; |
| 9987 | + } |
| 9988 | + |
| 9989 | + const int h_start = (HEADS * ith) / nth; |
| 9990 | + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? |
| 9991 | + (HEADS * (ith + 1)) / nth : HEADS; |
9982 | 9992 |
|
9983 | 9993 | float * k = (float *) dst->src[0]->data; |
9984 | 9994 | float * v = (float *) dst->src[1]->data; |
@@ -10189,9 +10199,13 @@ static void ggml_compute_forward_gla_f32( |
10189 | 10199 | const int ith = params->ith; |
10190 | 10200 | const int nth = params->nth; |
10191 | 10201 |
|
10192 | | - const int h_start = (HEADS * (ith )) / nth; |
10193 | | - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? |
10194 | | - (HEADS * (ith + 1)) / nth : HEADS; |
| 10202 | + if (ith >= HEADS) { |
| 10203 | + return; |
| 10204 | + } |
| 10205 | + |
| 10206 | + const int h_start = (HEADS * ith) / nth; |
| 10207 | + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? |
| 10208 | + (HEADS * (ith + 1)) / nth : HEADS; |
10195 | 10209 |
|
10196 | 10210 | float * k = (float *) dst->src[0]->data; |
10197 | 10211 | float * v = (float *) dst->src[1]->data; |
@@ -10746,9 +10760,13 @@ static void ggml_compute_forward_rwkv_wkv7_f32( |
10746 | 10760 | const int ith = params->ith; |
10747 | 10761 | const int nth = params->nth; |
10748 | 10762 |
|
10749 | | - const int h_start = (HEADS * (ith )) / nth; |
10750 | | - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? |
10751 | | - (HEADS * (ith + 1)) / nth : HEADS; |
| 10763 | + if (ith >= HEADS) { |
| 10764 | + return; |
| 10765 | + } |
| 10766 | + |
| 10767 | + const int h_start = (HEADS * ith) / nth; |
| 10768 | + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? |
| 10769 | + (HEADS * (ith + 1)) / nth : HEADS; |
10752 | 10770 |
|
10753 | 10771 | float * r = (float *) dst->src[0]->data; |
10754 | 10772 | float * w = (float *) dst->src[1]->data; |
|
0 commit comments