Skip to content

Commit 6c59bc7

Browse files
authored
perf: exl3 decode kernel optimization experiments (#1655)
* perf: exl3 decode kernel optimization experiments

  Signed-off-by: AlpinDale <alpindale@gmail.com>

* fix: remove unsafe EXL3 shape overrides

  Signed-off-by: AlpinDale <alpindale@gmail.com>

---------

Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent c005aac commit 6c59bc7

4 files changed

Lines changed: 17 additions & 14 deletions

File tree

aphrodite/config/model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,7 @@ def __post_init__(
585585
if self.pooler_config.tok_pooling_type is None:
586586
self.pooler_config.tok_pooling_type = default_tok_pooling_type
587587

588+
requested_dtype = self.dtype
588589
self.dtype: torch.dtype = _get_and_verify_dtype(
589590
self.model,
590591
self.hf_config,
@@ -667,6 +668,15 @@ def __post_init__(
667668
self.config_updated = False
668669
self._try_verify_and_update_model_config()
669670
self._verify_quantization()
671+
if (
672+
self.quantization == "exl3"
673+
and isinstance(requested_dtype, str)
674+
and requested_dtype.lower() == "auto"
675+
and self.dtype != torch.float16
676+
and "moe" in self.hf_config.model_type.lower()
677+
):
678+
logger.info("Defaulting EXL3 activation dtype from %s to torch.float16.", self.dtype)
679+
self.dtype = torch.float16
670680
self._verify_cuda_graph()
671681
self._verify_bnb_config()
672682

aphrodite/model_executor/layers/quantization/exl3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ def _exl3_gate_up(
184184
-1,
185185
0,
186186
)
187+
if x.shape[0] == 1:
188+
return output.view(1, out_features * 2)
187189
return torch.cat([output[0], output[1]], dim=-1)
188190

189191

csrc/quantization/exl3/exllamav3_ext/quant/exl3_gemm_kernel.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,4 +223,4 @@ __global__ __launch_bounds__(EXL3_GEMM_BASE_THREADS* TILESIZE_K /
223223
}
224224
}
225225
}
226-
}
226+
}

csrc/quantization/exl3/exllamav3_ext/quant/exl3_kernel_map.cu

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -159,19 +159,10 @@ fp_exl3_mgemm_kernel select_exl3_mgemm_kernel(
159159
int cc, int size_m, int size_k, int size_n, int K, bool c_fp32,
160160
int force_shape_idx, int* out_block_dim, int* out_shape_idx, int* num_sms,
161161
int cb, int bszm_in, int bszm_out) {
162-
int shape_idx;
163-
if (force_shape_idx > 0) {
164-
shape_idx = force_shape_idx;
165-
} else if (cc == CC_BLACKWELL && K == 4 && size_m == 1 && size_k == 1024 &&
166-
size_n == 256 && bszm_out <= 32) {
167-
shape_idx = 4;
168-
} else if (cc == CC_BLACKWELL && K == 4 && size_m == 1 && size_k == 1024 &&
169-
size_n == 3072 && bszm_out == 2) {
170-
shape_idx = 2;
171-
} else {
172-
shape_idx = select_gemm_shape(cc, size_m, size_k, size_n, K, true, bszm_in,
173-
bszm_out);
174-
}
162+
int shape_idx = force_shape_idx <= 0
163+
? select_gemm_shape(cc, size_m, size_k, size_n, K, true,
164+
bszm_in, bszm_out)
165+
: force_shape_idx;
175166
TORCH_CHECK(shape_idx > 0, "exl3_mgemm: no compatible kernel");
176167
if (out_shape_idx) *out_shape_idx = shape_idx;
177168
if (out_block_dim) *out_block_dim = exl3_gemm_blockdim[shape_idx];

0 commit comments

Comments
 (0)