Skip to content

Commit 2bb1567

Browse files
committed
Reduce compile times by dividing mmf into f16, bf16 and f32 variants
1 parent 5416217 commit 2bb1567

6 files changed

Lines changed: 45 additions & 0 deletions

File tree

ggml/src/ggml-cuda/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
4444
list(APPEND GGML_SOURCES_CUDA ${SRCS})
4545
file(GLOB SRCS "template-instances/mmq*.cu")
4646
list(APPEND GGML_SOURCES_CUDA ${SRCS})
47+
file(GLOB SRCS "template-instances/mmf*.cu")
48+
list(APPEND GGML_SOURCES_CUDA ${SRCS})
4749

4850
if (GGML_CUDA_FA_ALL_QUANTS)
4951
file(GLOB SRCS "template-instances/fattn-vec*.cu")

ggml/src/ggml-cuda/mmf.cuh

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,16 @@
33
// Prototype only; the definition is not visible in this hunk. mul_mat_f_case below forwards all of its arguments here.
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
44

55
// NOTE(review): `scr0_ne` looks like a typo for `src0_ne` -- harmless in a declaration, but confirm against the definition before renaming.
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
6+
7+
// Typed wrapper around ggml_cuda_mul_mat_f: asserts that src0->type matches the
// template argument `type`, then forwards every argument unchanged.
template <ggml_type type>
8+
void mul_mat_f_case(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
9+
// Check that the runtime tensor type equals the compile-time `type`.
GGML_ASSERT(src0->type == type);
10+
ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
11+
}
12+
13+
// Expands to an explicit instantiation of mul_mat_f_case<type>; the trailing
// semicolon is supplied at the use site.
#define DECL_MMF_CASE(type) \
14+
template void mul_mat_f_case<type>(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst)
15+
16+
// Prefixed with `extern`, each use below is an explicit instantiation declaration
// (F32, F16, BF16) rather than a definition.
extern DECL_MMF_CASE(GGML_TYPE_F32);
17+
extern DECL_MMF_CASE(GGML_TYPE_F16);
18+
extern DECL_MMF_CASE(GGML_TYPE_BF16);

ggml/src/ggml-cuda/template-instances/generate_cu_files.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,17 @@
3434
DECL_MMQ_CASE({type});
3535
"""
3636

37+
# ggml type names for which this script writes an mmf-instance-*.cu file (see the loop over TYPES_MMF).
TYPES_MMF = [
38+
"GGML_TYPE_F32", "GGML_TYPE_F16", "GGML_TYPE_BF16"
39+
]
40+
41+
# Text of each generated mmf instance file; `{type}` is filled in with str.format.
SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
42+
43+
#include "../mmf.cuh"
44+
45+
DECL_MMF_CASE({type});
46+
"""
47+
3748

3849
# Strip the "GGML_TYPE_" prefix text and lowercase the rest, e.g. "GGML_TYPE_F16" -> "f16".
def get_short_name(long_quant_name):
3950
return long_quant_name.replace("GGML_TYPE_", "").lower()
@@ -76,3 +87,7 @@ def get_head_sizes(type_k, type_v):
7687
# Write one mmq-instance-<short name>.cu per entry in TYPES_MMQ, formatting SOURCE_MMQ with that type.
for type in TYPES_MMQ:
7788
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
7889
f.write(SOURCE_MMQ.format(type=type))
90+
91+
# Write one mmf-instance-<short name>.cu per entry in TYPES_MMF, formatting SOURCE_MMF with that type.
for type in TYPES_MMF:
92+
with open(f"mmf-instance-{get_short_name(type)}.cu", "w") as f:
93+
f.write(SOURCE_MMF.format(type=type))
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../mmf.cuh"
4+
5+
DECL_MMF_CASE(GGML_TYPE_BF16);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../mmf.cuh"
4+
5+
DECL_MMF_CASE(GGML_TYPE_F16);
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
2+
3+
#include "../mmf.cuh"
4+
5+
DECL_MMF_CASE(GGML_TYPE_F32);

0 commit comments

Comments
 (0)