Skip to content

Commit bc12fd1

Browse files
committed
Divide mmf by ncols_dst
1 parent 2bb1567 commit bc12fd1

8 files changed

Lines changed: 463 additions & 456 deletions

File tree

ggml/src/ggml-cuda/mma.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#pragma once
12
// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
23
// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
34
// The documentation for the PTX instructions can be found under:

ggml/src/ggml-cuda/mmf.cu

Lines changed: 0 additions & 425 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cuda/mmf.cuh

Lines changed: 460 additions & 9 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cuda/mmvf.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ static void launch_mul_mat_vec_f_cuda(
163163
const int nbytes_shared = warp_size*sizeof(float);
164164
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
165165
const dim3 block_dims(block_size_best, 1, 1);
166-
167166
switch (block_size_best) {
168167
case 32: {
169168
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>

ggml/src/ggml-cuda/template-instances/generate_cu_files.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,6 @@
3434
DECL_MMQ_CASE({type});
3535
"""
3636

37-
TYPES_MMF = [
38-
"GGML_TYPE_F32", "GGML_TYPE_F16", "GGML_TYPE_BF16"
39-
]
40-
4137
SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
4238
4339
#include "../mmf.cuh"
@@ -88,6 +84,6 @@ def get_head_sizes(type_k, type_v):
8884
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
8985
f.write(SOURCE_MMQ.format(type=type))
9086

91-
for type in TYPES_MMF:
92-
with open(f"mmf-instance-{get_short_name(type)}.cu", "w") as f:
87+
for type in range(1, 17):
88+
with open(f"mmf-instance-ncols_{type}.cu", "w") as f:
9389
f.write(SOURCE_MMF.format(type=type))

ggml/src/ggml-cuda/template-instances/mmf-instance-bf16.cu

Lines changed: 0 additions & 5 deletions
This file was deleted.

ggml/src/ggml-cuda/template-instances/mmf-instance-f16.cu

Lines changed: 0 additions & 5 deletions
This file was deleted.

ggml/src/ggml-cuda/template-instances/mmf-instance-f32.cu

Lines changed: 0 additions & 5 deletions
This file was deleted.

0 commit comments

Comments
 (0)