@@ -27,6 +27,7 @@ template <int n_expert_used_template>
2727__launch_bounds__ (ggml_cuda_get_physical_warp_size(), 1)
2828static __global__ void mm_ids_helper(
2929 const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
30+ const float * __restrict__ scales, float * __restrict__ scales_src1,
3031 const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
3132 constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
3233 const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
@@ -100,6 +101,9 @@ static __global__ void mm_ids_helper(
100101 const int iex_used = store_it.iex_used ();
101102 ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
102103 ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
104+ if (scales_src1) {
105+ scales_src1[nex_prev + itc] = scales[expert];
106+ }
103107 }
104108
105109 if (threadIdx .x != 0 ) {
@@ -118,6 +122,7 @@ static __global__ void mm_ids_helper(
118122template <int n_expert_used_template>
119123static void launch_mm_ids_helper (
120124 const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
125+ const float * __restrict__ scales, float * __restrict__ scales_src1,
121126 const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
122127 GGML_ASSERT (n_tokens < (1 << 22 ) && " too few bits in mm_ids_helper_store" );
123128 GGML_ASSERT (n_expert_used_var < (1 << 10 ) && " too few bits in mm_ids_helper_store" );
@@ -132,33 +137,34 @@ static void launch_mm_ids_helper(
132137 const size_t nbytes_shared = n_tokens*sizeof (mm_ids_helper_store);
133138 GGML_ASSERT (nbytes_shared <= smpbo);
134139 mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
135- (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
140+ (ids, ids_src1, ids_dst, expert_bounds, scales, scales_src1, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
136141}
137142
// Host entry point: launches mm_ids_helper using the template instantiation that matches
// n_expert_used.
//
// The expert counts 2, 4, 6, 8, 16 and 32 are dispatched to specializations in which the count
// is a compile-time constant. Any other value uses the <0> instantiation, where the kernel takes
// the count from the runtime argument instead.
//
// scales / scales_src1 form an optional pair: the kernel writes scales[expert] into scales_src1
// whenever scales_src1 is non-null, so a non-null scales_src1 requires a non-null scales.
// Passing nullptr for scales_src1 disables that copy.
//
// All other arguments are forwarded unchanged to launch_mm_ids_helper.
void ggml_cuda_launch_mm_ids_helper(
        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
        const float * __restrict__ scales, float * __restrict__ scales_src1,
        const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
    // The kernel only tests scales_src1 before reading scales; reject the combination that would
    // make it dereference a null pointer on the device.
    GGML_ASSERT((scales_src1 == nullptr || scales != nullptr) && "scales_src1 requires scales");

    switch (n_expert_used) {
        case  2:
            launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, scales, scales_src1, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case  4:
            launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, scales, scales_src1, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case  6:
            launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, scales, scales_src1, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case  8:
            launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, scales, scales_src1, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 16:
            launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, scales, scales_src1, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 32:
            launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, scales, scales_src1, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        default:
            // No specialization for this count: the <0> instantiation uses the runtime value.
            launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, scales, scales_src1, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
    }
}
0 commit comments