Skip to content

Commit cf21cdf

Browse files
kleidiai: add data type check to get_tensor_traits (#20639)
* kleidiai: add data type check to get_tensor_traits * Added check for F16 data type into get_tensor_traits path with input data not in ggml_backend_cpu_kleidiai_buffer_type format (unsupported for Q4/8) Signed-off-by: Martin Klacer <martin.klacer@arm.com> Change-Id: I9aca4b9b8d669d35db6f1dbcc4e080b1919b1de7 * updated ggml/src/ggml-cpu/kleidiai/kleidiai.cpp updated kleidiai.cpp file as per suggestion Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Signed-off-by: Martin Klacer <martin.klacer@arm.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
1 parent 0ed9929 commit cf21cdf

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

ggml/src/ggml-cpu/kleidiai/kleidiai.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,10 +1473,12 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
14731473
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
14741474
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
14751475
} else {
1476+
if (op->src[0]->type != GGML_TYPE_F16) {
1477+
return nullptr;
1478+
}
14761479
std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain;
14771480
const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain);
1478-
const bool has_kernel = slot_total > 0;
1479-
if (has_kernel && op->src[1]->ne[1] > 1) {
1481+
if (slot_total > 0 && op->src[1]->ne[1] > 1) {
14801482
if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
14811483
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
14821484
return nullptr;

0 commit comments

Comments
 (0)