Skip to content

Commit e529ad8

Browse files
committed
WIP for input_scale consumption by CUDA backend
1 parent 33d3fc5 commit e529ad8

7 files changed

Lines changed: 121 additions & 60 deletions

File tree

ggml/src/ggml-cuda/mmf.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
8484

8585
GGML_ASSERT(sis1 > 0);
8686

87-
ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
87+
ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(), nullptr, nullptr,
8888
static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
8989
CUDA_CHECK(cudaGetLastError());
9090

ggml/src/ggml-cuda/mmid.cu

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ template <int n_expert_used_template>
2727
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
2828
static __global__ void mm_ids_helper(
2929
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
30+
const float * __restrict__ scales, float * __restrict__ scales_src1,
3031
const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
3132
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
3233
const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
@@ -100,6 +101,9 @@ static __global__ void mm_ids_helper(
100101
const int iex_used = store_it.iex_used();
101102
ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
102103
ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
104+
if (scales_src1) {
105+
scales_src1[nex_prev + itc] = scales[expert];
106+
}
103107
}
104108

105109
if (threadIdx.x != 0) {
@@ -118,6 +122,7 @@ static __global__ void mm_ids_helper(
118122
template <int n_expert_used_template>
119123
static void launch_mm_ids_helper(
120124
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
125+
const float * __restrict__ scales, float * __restrict__ scales_src1,
121126
const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
122127
GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mm_ids_helper_store");
123128
GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store");
@@ -132,33 +137,34 @@ static void launch_mm_ids_helper(
132137
const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store);
133138
GGML_ASSERT(nbytes_shared <= smpbo);
134139
mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
135-
(ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
140+
(ids, ids_src1, ids_dst, expert_bounds, scales, scales_src1, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
136141
}
137142

138143
void ggml_cuda_launch_mm_ids_helper(
139144
const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
145+
const float * __restrict__ scales, float * __restrict__ scales_src1,
140146
const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
141147
switch (n_expert_used) {
142148
case 2:
143-
launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
149+
launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, scales, scales_src1, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
144150
break;
145151
case 4:
146-
launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
152+
launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, scales, scales_src1, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
147153
break;
148154
case 6:
149-
launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
155+
launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, scales, scales_src1, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
150156
break;
151157
case 8:
152-
launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
158+
launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, scales, scales_src1, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
153159
break;
154160
case 16:
155-
launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
161+
launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, scales, scales_src1, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
156162
break;
157163
case 32:
158-
launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
164+
launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, scales, scales_src1, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
159165
break;
160166
default:
161-
launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
167+
launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, scales, scales_src1, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
162168
break;
163169
}
164170
}

ggml/src/ggml-cuda/mmid.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22

33
void ggml_cuda_launch_mm_ids_helper(
44
const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
5+
const float * scales, float * scales_src1,
56
int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);

ggml/src/ggml-cuda/mmq.cu

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ void ggml_cuda_mul_mat_q(
123123

124124
// TODO: tighter pool buffer size vs q8 path
125125
const bool use_native_fp4 = blackwell_mma_available(cc) && (src0->type == GGML_TYPE_MXFP4 || src0->type == GGML_TYPE_NVFP4);
126+
const ggml_tensor * scale_activations = src0->type == GGML_TYPE_NVFP4 ? (ids ? dst->src[4] : dst->src[3]) : nullptr;
127+
const float * scale_activations_d = scale_activations ? (const float *) scale_activations->data : nullptr;
128+
const int64_t n_scale_activations = scale_activations ? ggml_nelements(scale_activations) : 0;
126129

127130
if (!ids) {
128131
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
@@ -135,7 +138,7 @@ void ggml_cuda_mul_mat_q(
135138
const int64_t s13 = src1->nb[3] / ts_src1;
136139
if (use_native_fp4) {
137140
static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
138-
quantize_mmq_fp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
141+
quantize_mmq_fp4_cuda(src1_d, nullptr, scale_activations_d, n_scale_activations, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
139142
ne11, ne12, ne13, stream);
140143

141144
} else {
@@ -152,7 +155,9 @@ void ggml_cuda_mul_mat_q(
152155
const int64_t s13 = ne12*s12;
153156

154157
const mmq_args args = {
155-
src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
158+
src0_d, src0->type, (const int *) src1_q8_1.ptr,
159+
use_native_fp4 ? scale_activations_d : nullptr, use_native_fp4 ? n_scale_activations : 0,
160+
nullptr, nullptr, dst_d,
156161
ne00, ne01, ne1, s01, ne11, s1,
157162
ne02, ne12, s02, s12, s2,
158163
ne03, ne13, s03, s13, s3,
@@ -172,13 +177,25 @@ void ggml_cuda_mul_mat_q(
172177
ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
173178
ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
174179
ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
180+
ggml_cuda_pool_alloc<float> scale_activations_src1(ctx.pool());
181+
const float * scale_activations_q = scale_activations_d;
182+
int64_t n_scale_activations_q = n_scale_activations;
183+
if (scale_activations) {
184+
GGML_ASSERT(n_scale_activations == 1 || n_scale_activations == ne02);
185+
if (n_scale_activations != 1) {
186+
scale_activations_src1.alloc(ctx.pool(), ne_get_rows);
187+
scale_activations_q = scale_activations_src1.get();
188+
n_scale_activations_q = ne_get_rows;
189+
}
190+
}
175191

176192
{
177193
GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
178194
const int si1 = ids->nb[1] / ggml_element_size(ids);
179195
const int sis1 = nb12 / nb11;
180196

181197
ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
198+
n_scale_activations == 1 ? nullptr : scale_activations_d, n_scale_activations == 1 ? nullptr : scale_activations_src1.get(),
182199
ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
183200
CUDA_CHECK(cudaGetLastError());
184201
}
@@ -197,7 +214,7 @@ void ggml_cuda_mul_mat_q(
197214
const int64_t s13 = src1->nb[3] / ts_src1;
198215

199216
if (use_native_fp4) {
200-
quantize_mmq_fp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
217+
quantize_mmq_fp4_cuda(src1_d, ids_src1.get(), scale_activations_q, n_scale_activations_q, src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
201218
ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
202219
} else {
203220
quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
@@ -213,7 +230,9 @@ void ggml_cuda_mul_mat_q(
213230

214231
// Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
215232
const mmq_args args = {
216-
src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
233+
src0_d, src0->type, (const int *) src1_q8_1.get(),
234+
use_native_fp4 ? scale_activations_q : nullptr, use_native_fp4 ? n_scale_activations_q : 0,
235+
ids_dst.get(), expert_bounds.get(), dst_d,
217236
ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
218237
ne02, ne02, s02, s12, s2,
219238
ne03, ne13, s03, s13, s3,
@@ -253,7 +272,7 @@ void ggml_cuda_op_mul_mat_q(
253272
|| GGML_CUDA_CC_IS_CDNA(cc))
254273
&& src1_ncols == ne11;
255274
const mmq_args args = {
256-
src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
275+
src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, 0, nullptr, nullptr, dst_dd_i,
257276
ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
258277
1, 1, 0, 0, 0,
259278
1, 1, 0, 0, 0,

0 commit comments

Comments
 (0)