|
// Half-precision (fp16) types and built-ins are used throughout this file.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// The extension-name macro is only predefined by compilers that support the
// extension, so the pragma is requested inside the guard. Enabling it
// unconditionally can be reported as an error on implementations that do not
// know cl_qcom_reqd_sub_group_size.
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
// Kernel attribute requesting the "full" sub-group size.
// NOTE(review): the macro name implies "full" means 128 work-items - confirm
// against the Qualcomm extension documentation.
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
| 8 | + |
// IQ4_NL codebook: the half value associated with each 4-bit index 0..15.
// Indices 0..7 hold the negative entries, indices 8..15 the positive ones.
// NOTE(review): nothing in the visible part of this file reads this table
// (the kernel below goes through iq4nl_packed) - confirm before removing.
constant half kvalues_iq4nl[16] = {
    (half)-127.f, (half)-104.f, (half)-83.f, (half)-65.f, (half)-49.f, (half)-35.f, (half)-22.f, (half)-10.f,
    (half)   1.f, (half)  13.f, (half) 25.f, (half) 38.f, (half) 53.f, (half) 69.f, (half) 89.f, (half)113.f
};
| 15 | + |
// Packed IQ4_NL codebook: each uint carries the raw FP16 bit patterns of two
// consecutive codebook entries, so the 16 values fit in 8 words.
// Entry 2*w sits in bits [15:0] of word w, entry 2*w+1 in bits [31:16].
#define IQ4NL_PACK_HALF2(lo_bits, hi_bits) (((hi_bits) << 16) | (lo_bits))
constant uint iq4nl_packed[8] = {
    IQ4NL_PACK_HALF2(0xD7F0u, 0xD680u), // idx  0, 1: -127, -104
    IQ4NL_PACK_HALF2(0xD530u, 0xD410u), // idx  2, 3:  -83,  -65
    IQ4NL_PACK_HALF2(0xD220u, 0xD060u), // idx  4, 5:  -49,  -35
    IQ4NL_PACK_HALF2(0xCD80u, 0xC900u), // idx  6, 7:  -22,  -10
    IQ4NL_PACK_HALF2(0x3C00u, 0x4A80u), // idx  8, 9:    1,   13
    IQ4NL_PACK_HALF2(0x4E40u, 0x50C0u), // idx 10,11:   25,   38
    IQ4NL_PACK_HALF2(0x52A0u, 0x5450u), // idx 12,13:   53,   69
    IQ4NL_PACK_HALF2(0x5590u, 0x5710u)  // idx 14,15:   89,  113
};
#undef IQ4NL_PACK_HALF2
| 27 | + |
// Decode one 4-bit IQ4_NL index into its half value using iq4nl_packed:
//   1. fetch the 32-bit word holding the entry pair (index / 2),
//   2. shift right by 16 when the index is odd, by 0 when it is even,
//   3. keep the low 16 bits and reinterpret them as a half.
// The argument is expanded twice, so it must be free of side effects.
#define IQ4_NL_PACKED_WORD(nibble) (iq4nl_packed[(nibble) >> 1])
#define IQ4_NL_LANE_SHIFT(nibble)  (((nibble) & 1u) << 4)
#define IQ4_NL_DEQUANT(nibble) \
    as_half((ushort)(IQ4_NL_PACKED_WORD(nibble) >> IQ4_NL_LANE_SHIFT(nibble)))
| 30 | + |
// One step along K for the kernel below. Expands in place and uses the
// kernel's locals (src1, gy, n_4, bits4, scale, B, w, c0..c3) by name.
//   kk    - K index whose 8 activations are fetched from src1 (two half4 texels)
//   shift - bit offset of the 4-bit weight index inside each ushort of bits4
#define IQ4_NL_GEMM_STEP(kk, shift)                                          \
    do {                                                                     \
        B.s0123 = read_imageh(src1, gy*2 + (kk)*(n_4));                      \
        B.s4567 = read_imageh(src1, gy*2 + (kk)*(n_4)+1);                    \
        w.s0 = IQ4_NL_DEQUANT((bits4.s0 >> (shift)) & 0x000Fu) * scale.s0;   \
        w.s1 = IQ4_NL_DEQUANT((bits4.s1 >> (shift)) & 0x000Fu) * scale.s1;   \
        w.s2 = IQ4_NL_DEQUANT((bits4.s2 >> (shift)) & 0x000Fu) * scale.s2;   \
        w.s3 = IQ4_NL_DEQUANT((bits4.s3 >> (shift)) & 0x000Fu) * scale.s3;   \
        c0 += B * w.s0;                                                      \
        c1 += B * w.s1;                                                      \
        c2 += B * w.s2;                                                      \
        c3 += B * w.s3;                                                      \
    } while (0)

// Writes component `lane` of the four accumulators as one float4 at dst + idx
// when the 4-element span lies below out_limit, then advances idx by m.
// Once one row fails the check idx stops advancing, so later rows fail too.
#define IQ4_NL_STORE_ROW(lane)                                                   \
    do {                                                                         \
        if (idx + 3 < out_limit) {                                               \
            vstore4((float4)(c0.lane, c1.lane, c2.lane, c3.lane), 0, dst + idx); \
            idx += m;                                                            \
        }                                                                        \
    } while (0)

#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_128
#endif

// GEMM with IQ4_NL-quantized weights, half accumulation and float output.
//
// Each work-item produces an 8 x 4 tile of dst: 8 consecutive rows selected
// by get_global_id(0), 4 consecutive elements selected by get_global_id(1),
// with a row stride of m floats.
//
//   src0_q        packed weights; every ushort carries four 4-bit indices
//                 (bits 0-3, 4-7, 8-11, 12-15 serve K steps i, i+1, i+2, i+3).
//                 Indexed as [(i/4)*m + 4*gx .. +3].
//   src0_d        half scales, indexed as [(i/32)*m + 4*gx .. +3].
//   src1          activations, read as half4 texels; texel index is
//                 gy*2 + kk*(n/4) (+1 for the second half of the 8 values).
//   dst           float output; offsetd (in bytes) is applied first.
//   m             row stride of dst and of the weight/scale planes.
//   n             only used as n/4, the per-K texel stride of src1.
//   k             reduction length; consumed 4 at a time.
//                 NOTE(review): presumably a multiple of 32 - confirm with host.
//   n_no_padding  stores are suppressed once idx + 3 >= m * n_no_padding.
kernel void kernel_gemm_noshuffle_iq4_nl_f32(
        global const ushort * src0_q,
        global const half * src0_d,
        read_only image1d_buffer_t src1,
        global float * dst,
        ulong offsetd,
        int m,
        int n,
        int k,
        int n_no_padding
) {
    // offsetd is a byte offset, hence the detour through a char pointer.
    dst = (global float *)((global char *)dst + offsetd);

    const int n_4 = n >> 2;            // src1 texels (half4) per K index

    const int gy   = get_global_id(0); // selects 8 output rows
    const int gx   = get_global_id(1); // selects 4 output columns
    const int col0 = gx << 2;          // first of the 4 columns

    // cN.sR accumulates the output at row gy*8 + R, column col0 + N.
    half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
    half8 B;
    half4 w;

    global const ushort * weight_ptr = src0_q + col0;
    global const half   * scale_ptr  = src0_d + col0;

    for (int i = 0; i < k; i += 4) {
        // One ushort per column holds the indices for K = i .. i+3.
        const ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m));
        // One scale per column, shared by 32 consecutive K values.
        const half4 scale = vload4(0, scale_ptr + (i/32)*(m));

        IQ4_NL_GEMM_STEP(i,      0);
        IQ4_NL_GEMM_STEP(i + 1,  4);
        IQ4_NL_GEMM_STEP(i + 2,  8);
        IQ4_NL_GEMM_STEP(i + 3, 12);
    }

    int idx = (gy << 3)*m + col0;
    const int out_limit = m*n_no_padding;

    IQ4_NL_STORE_ROW(s0);
    IQ4_NL_STORE_ROW(s1);
    IQ4_NL_STORE_ROW(s2);
    IQ4_NL_STORE_ROW(s3);
    IQ4_NL_STORE_ROW(s4);
    IQ4_NL_STORE_ROW(s5);
    IQ4_NL_STORE_ROW(s6);
    IQ4_NL_STORE_ROW(s7);
}

#undef IQ4_NL_STORE_ROW
#undef IQ4_NL_GEMM_STEP
0 commit comments